or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

client.md configuration.md data.md frameworks.md genai.md index.md models.md projects.md tracing.md tracking.md

docs/frameworks.md

0

# ML Framework Integrations

1

2

MLflow provides comprehensive integrations with popular machine learning and deep learning frameworks, enabling seamless model logging, loading, and deployment across different ML ecosystems. Each integration offers framework-specific optimizations and native model format support.

3

4

## Capabilities

5

6

### Scikit-learn Integration

7

8

Native integration for scikit-learn models with automatic dependency management and preprocessing pipeline support.

9

10

```python { .api }

11

import mlflow.sklearn

12

13

def log_model(sk_model, artifact_path, conda_env=None, code_paths=None, registered_model_name=None, signature=None, input_example=None, await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS, pip_requirements=None, extra_pip_requirements=None, serialization_format=SERIALIZATION_FORMAT_PICKLE, metadata=None, **kwargs):

14

"""

15

Log scikit-learn model as MLflow artifact.

16

17

Parameters:

18

- sk_model: Trained scikit-learn model object

19

- artifact_path: str - Run-relative artifact path

20

- conda_env: str or dict, optional - Conda environment specification

21

- code_paths: list, optional - List of local code paths to include

22

- registered_model_name: str, optional - Name for model registry

23

- signature: ModelSignature, optional - Model input/output schema

24

- input_example: Any, optional - Example input for inference

25

- await_registration_for: int - Seconds to wait for registration

26

- pip_requirements: list, optional - List of pip package requirements

27

- extra_pip_requirements: list, optional - Additional pip requirements

28

- serialization_format: str - Serialization format (pickle, cloudpickle)

29

- metadata: dict, optional - Custom model metadata

30

31

Returns:

32

ModelInfo object with logged model details

33

"""

34

35

def load_model(model_uri, dst_path=None):

36

"""

37

Load scikit-learn model from MLflow.

38

39

Parameters:

40

- model_uri: str - URI pointing to MLflow model

41

- dst_path: str, optional - Local destination path

42

43

Returns:

44

Loaded scikit-learn model object

45

"""

46

47

def save_model(sk_model, path, conda_env=None, code_paths=None, mlflow_model=None, signature=None, input_example=None, pip_requirements=None, extra_pip_requirements=None, serialization_format=SERIALIZATION_FORMAT_PICKLE, metadata=None):

48

"""

49

Save scikit-learn model to local path.

50

51

Parameters:

52

- sk_model: Trained scikit-learn model object

53

- path: str - Local path to save model

54

- conda_env: str or dict, optional - Conda environment

55

- code_paths: list, optional - Code dependencies to include

56

- mlflow_model: Model, optional - MLflow model configuration

57

- signature: ModelSignature, optional - Model signature

58

- input_example: Any, optional - Example input

59

- pip_requirements: list, optional - Pip package requirements

60

- extra_pip_requirements: list, optional - Additional pip requirements

61

- serialization_format: str - Serialization format

62

- metadata: dict, optional - Custom metadata

63

"""

64

```

65

66

### PyTorch Integration

67

68

Comprehensive PyTorch support including standard models, PyTorch Lightning, and TorchScript compilation.

69

70

```python { .api }

71

import mlflow.pytorch

72

73

def log_model(pytorch_model, artifact_path, conda_env=None, code_paths=None, pickle_module=None, registered_model_name=None, signature=None, input_example=None, await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS, requirements_file=None, extra_files=None, pip_requirements=None, extra_pip_requirements=None, metadata=None, **kwargs):

74

"""

75

Log PyTorch model as MLflow artifact.

76

77

Parameters:

78

- pytorch_model: PyTorch model object or state_dict

79

- artifact_path: str - Run-relative artifact path

80

- conda_env: str or dict, optional - Conda environment

81

- code_paths: list, optional - Local code paths to include

82

- pickle_module: module, optional - Module for model serialization

83

- registered_model_name: str, optional - Registry model name

84

- signature: ModelSignature, optional - Model schema

85

- input_example: Any, optional - Example model input

86

- await_registration_for: int - Registration wait time

87

- requirements_file: str, optional - Path to requirements file

88

- extra_files: list, optional - Additional files to include

89

- pip_requirements: list, optional - Pip requirements

90

- extra_pip_requirements: list, optional - Additional pip requirements

91

- metadata: dict, optional - Custom metadata

92

93

Returns:

94

ModelInfo object

95

"""

96

97

def load_model(model_uri, map_location=None, dst_path=None):

98

"""

99

Load PyTorch model from MLflow.

100

101

Parameters:

102

- model_uri: str - URI pointing to MLflow model

103

- map_location: str or torch.device, optional - Device mapping for loading

104

- dst_path: str, optional - Local destination path

105

106

Returns:

107

Loaded PyTorch model object

108

"""

109

110

def log_state_dict(state_dict, artifact_path, **kwargs):

111

"""

112

Log PyTorch model state dictionary.

113

114

Parameters:

115

- state_dict: dict - PyTorch model state dictionary

116

- artifact_path: str - Artifact path for state dict

117

- kwargs: Additional logging arguments

118

"""

119

120

def load_state_dict(model_uri, map_location=None):

121

"""

122

Load PyTorch state dictionary from MLflow.

123

124

Parameters:

125

- model_uri: str - URI pointing to saved state dict

126

- map_location: str or device, optional - Device for loading

127

128

Returns:

129

PyTorch state dictionary

130

"""

131

```

132

133

### TensorFlow Integration

134

135

Full TensorFlow support including Keras models, SavedModel format, and TensorFlow Serving compatibility.

136

137

```python { .api }

138

import mlflow.tensorflow

139

140

def log_model(tf_saved_model_dir=None, tf_meta_graph_tags=None, tf_signature_def_key=None, artifact_path=None, conda_env=None, code_paths=None, registered_model_name=None, signature=None, input_example=None, await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS, pip_requirements=None, extra_pip_requirements=None, metadata=None, **kwargs):

141

"""

142

Log TensorFlow model as MLflow artifact.

143

144

Parameters:

145

- tf_saved_model_dir: str - Path to TensorFlow SavedModel directory

146

- tf_meta_graph_tags: list, optional - MetaGraph tags to load

147

- tf_signature_def_key: str, optional - SignatureDef key for inference

148

- artifact_path: str - Run-relative artifact path

149

- conda_env: str or dict, optional - Conda environment

150

- code_paths: list, optional - Code dependencies

151

- registered_model_name: str, optional - Registry model name

152

- signature: ModelSignature, optional - Model schema

153

- input_example: Any, optional - Example input

154

- await_registration_for: int - Registration wait time

155

- pip_requirements: list, optional - Pip requirements

156

- extra_pip_requirements: list, optional - Additional pip requirements

157

- metadata: dict, optional - Custom metadata

158

159

Returns:

160

ModelInfo object

161

"""

162

163

def load_model(model_uri, dst_path=None):

164

"""

165

Load TensorFlow model from MLflow.

166

167

Parameters:

168

- model_uri: str - URI pointing to MLflow model

169

- dst_path: str, optional - Local destination path

170

171

Returns:

172

Loaded TensorFlow model object

173

"""

174

175

import mlflow.keras

176

177

def log_model(keras_model, artifact_path, conda_env=None, code_paths=None, custom_objects=None, keras_module=None, registered_model_name=None, signature=None, input_example=None, await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS, pip_requirements=None, extra_pip_requirements=None, metadata=None, **kwargs):

178

"""

179

Log Keras model as MLflow artifact.

180

181

Parameters:

182

- keras_model: Compiled Keras model object

183

- artifact_path: str - Run-relative artifact path

184

- conda_env: str or dict, optional - Conda environment

185

- code_paths: list, optional - Code dependencies

186

- custom_objects: dict, optional - Custom objects for model loading

187

- keras_module: module, optional - Keras module for compatibility

188

- registered_model_name: str, optional - Registry model name

189

- signature: ModelSignature, optional - Model schema

190

- input_example: Any, optional - Example input

191

- await_registration_for: int - Registration wait time

192

- pip_requirements: list, optional - Pip requirements

193

- extra_pip_requirements: list, optional - Additional pip requirements

194

- metadata: dict, optional - Custom metadata

195

196

Returns:

197

ModelInfo object

198

"""

199

```

200

201

### XGBoost Integration

202

203

Native XGBoost model support with automatic hyperparameter tracking and feature importance logging.

204

205

```python { .api }

206

import mlflow.xgboost

207

208

def log_model(xgb_model, artifact_path, conda_env=None, code_paths=None, registered_model_name=None, signature=None, input_example=None, await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS, pip_requirements=None, extra_pip_requirements=None, model_format="xgb", metadata=None, **kwargs):

209

"""

210

Log XGBoost model as MLflow artifact.

211

212

Parameters:

213

- xgb_model: Trained XGBoost model (Booster, XGBClassifier, XGBRegressor)

214

- artifact_path: str - Run-relative artifact path

215

- conda_env: str or dict, optional - Conda environment

216

- code_paths: list, optional - Code dependencies

217

- registered_model_name: str, optional - Registry model name

218

- signature: ModelSignature, optional - Model schema

219

- input_example: Any, optional - Example input

220

- await_registration_for: int - Registration wait time

221

- pip_requirements: list, optional - Pip requirements

222

- extra_pip_requirements: list, optional - Additional requirements

223

- model_format: str - Save format ("xgb", "json", "ubj")

224

- metadata: dict, optional - Custom metadata

225

226

Returns:

227

ModelInfo object

228

"""

229

230

def load_model(model_uri, dst_path=None):

231

"""

232

Load XGBoost model from MLflow.

233

234

Parameters:

235

- model_uri: str - URI pointing to MLflow model

236

- dst_path: str, optional - Local destination path

237

238

Returns:

239

Loaded XGBoost model object

240

"""

241

242

def autolog(importance_type="weight", log_input_examples=False, log_model_signatures=True, log_models=True, disable=False, exclusive=False, disable_for_unsupported_versions=False, silent=False, registered_model_name=None):

243

"""

244

Enable automatic logging for XGBoost training.

245

246

Parameters:

247

- importance_type: str - Feature importance type to log

248

- log_input_examples: bool - Whether to log input examples

249

- log_model_signatures: bool - Whether to log model signatures

250

- log_models: bool - Whether to log trained models

251

- disable: bool - Disable autologging if True

252

- exclusive: bool - Exclusive autologging mode

253

- disable_for_unsupported_versions: bool - Skip for unsupported versions

254

- silent: bool - Suppress autolog warnings

255

- registered_model_name: str, optional - Auto-register model name

256

"""

257

```

258

259

### LightGBM Integration

260

261

Comprehensive LightGBM support with early stopping integration and automatic metric logging.

262

263

```python { .api }

264

import mlflow.lightgbm

265

266

def log_model(lgb_model, artifact_path, conda_env=None, code_paths=None, registered_model_name=None, signature=None, input_example=None, await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS, pip_requirements=None, extra_pip_requirements=None, metadata=None, **kwargs):

267

"""

268

Log LightGBM model as MLflow artifact.

269

270

Parameters:

271

- lgb_model: Trained LightGBM model (Booster, LGBMClassifier, LGBMRegressor)

272

- artifact_path: str - Run-relative artifact path

273

- conda_env: str or dict, optional - Conda environment

274

- code_paths: list, optional - Code dependencies

275

- registered_model_name: str, optional - Registry model name

276

- signature: ModelSignature, optional - Model schema

277

- input_example: Any, optional - Example input

278

- await_registration_for: int - Registration wait time

279

- pip_requirements: list, optional - Pip requirements

280

- extra_pip_requirements: list, optional - Additional requirements

281

- metadata: dict, optional - Custom metadata

282

283

Returns:

284

ModelInfo object

285

"""

286

287

def autolog(importance_type="split", log_input_examples=False, log_model_signatures=True, log_models=True, disable=False, exclusive=False, disable_for_unsupported_versions=False, silent=False, registered_model_name=None):

288

"""

289

Enable automatic logging for LightGBM training.

290

291

Parameters:

292

- importance_type: str - Feature importance type ("split", "gain")

293

- log_input_examples: bool - Log input examples

294

- log_model_signatures: bool - Log model signatures

295

- log_models: bool - Log trained models

296

- disable: bool - Disable autologging

297

- exclusive: bool - Exclusive autologging mode

298

- disable_for_unsupported_versions: bool - Skip unsupported versions

299

- silent: bool - Suppress warnings

300

- registered_model_name: str, optional - Auto-register model name

301

"""

302

```

303

304

### Transformers Integration

305

306

Hugging Face Transformers integration with support for various model types and tokenizers.

307

308

```python { .api }

309

import mlflow.transformers

310

311

def log_model(transformers_model, artifact_path, task=None, conda_env=None, code_paths=None, registered_model_name=None, signature=None, input_example=None, await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS, pip_requirements=None, extra_pip_requirements=None, metadata=None, tokenizer=None, feature_extractor=None, processor=None, model_config=None, **kwargs):

312

"""

313

Log Transformers model as MLflow artifact.

314

315

Parameters:

316

- transformers_model: Transformers model or pipeline object

317

- artifact_path: str - Run-relative artifact path

318

- task: str, optional - Task type for the model

319

- conda_env: str or dict, optional - Conda environment

320

- code_paths: list, optional - Code dependencies

321

- registered_model_name: str, optional - Registry model name

322

- signature: ModelSignature, optional - Model schema

323

- input_example: Any, optional - Example input

324

- await_registration_for: int - Registration wait time

325

- pip_requirements: list, optional - Pip requirements

326

- extra_pip_requirements: list, optional - Additional requirements

327

- metadata: dict, optional - Custom metadata

328

- tokenizer: Tokenizer, optional - Associated tokenizer

329

- feature_extractor: FeatureExtractor, optional - Feature extractor

330

- processor: Processor, optional - Processor object

331

- model_config: dict, optional - Model configuration

332

333

Returns:

334

ModelInfo object

335

"""

336

337

def load_model(model_uri, dst_path=None, device=None):

338

"""

339

Load Transformers model from MLflow.

340

341

Parameters:

342

- model_uri: str - URI pointing to MLflow model

343

- dst_path: str, optional - Local destination path

344

- device: str or int, optional - Device for model loading

345

346

Returns:

347

Loaded Transformers model or pipeline

348

"""

349

```

350

351

### Spark MLlib Integration

352

353

Apache Spark MLlib integration for distributed machine learning model logging and serving.

354

355

```python { .api }

356

import mlflow.spark

357

358

def log_model(spark_model, artifact_path, conda_env=None, code_paths=None, dfs_tmpdir=None, sample_input=None, registered_model_name=None, signature=None, input_example=None, await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS, pip_requirements=None, extra_pip_requirements=None, metadata=None, **kwargs):

359

"""

360

Log Spark MLlib model as MLflow artifact.

361

362

Parameters:

363

- spark_model: Fitted Spark MLlib model or pipeline

364

- artifact_path: str - Run-relative artifact path

365

- conda_env: str or dict, optional - Conda environment

366

- code_paths: list, optional - Code dependencies

367

- dfs_tmpdir: str, optional - Temporary directory for DFS operations

368

- sample_input: DataFrame, optional - Sample input for schema inference

369

- registered_model_name: str, optional - Registry model name

370

- signature: ModelSignature, optional - Model schema

371

- input_example: Any, optional - Example input

372

- await_registration_for: int - Registration wait time

373

- pip_requirements: list, optional - Pip requirements

374

- extra_pip_requirements: list, optional - Additional requirements

375

- metadata: dict, optional - Custom metadata

376

377

Returns:

378

ModelInfo object

379

"""

380

381

def load_model(model_uri, dfs_tmpdir=None):

382

"""

383

Load Spark MLlib model from MLflow.

384

385

Parameters:

386

- model_uri: str - URI pointing to MLflow model

387

- dfs_tmpdir: str, optional - Temporary directory for DFS

388

389

Returns:

390

Loaded Spark MLlib model or pipeline

391

"""

392

393

import mlflow.pyspark.ml

394

395

def autolog(disable=False, exclusive=False, disable_for_unsupported_versions=False, silent=False, log_models=True, log_input_examples=False, log_model_signatures=True, log_post_training_metrics=True, registered_model_name=None):

396

"""

397

Enable automatic logging for PySpark ML training.

398

399

Parameters:

400

- disable: bool - Disable autologging

401

- exclusive: bool - Exclusive autologging mode

402

- disable_for_unsupported_versions: bool - Skip unsupported versions

403

- silent: bool - Suppress warnings

404

- log_models: bool - Log trained models

405

- log_input_examples: bool - Log input examples

406

- log_model_signatures: bool - Log model signatures

407

- log_post_training_metrics: bool - Log evaluation metrics

408

- registered_model_name: str, optional - Auto-register model name

409

"""

410

```

411

412

### AG2 (AutoGen) Integration

413

414

Multi-agent conversation framework integration with automatic conversation logging and observability (experimental in MLflow 3.0.0).

415

416

```python { .api }

417

import mlflow.ag2

418

419

def autolog(disable=False, log_traces=True, log_models=False, log_input_examples=False, log_model_signatures=True, silent=False):

420

"""

421

Enable automatic logging for AG2 (AutoGen) conversations.

422

423

Parameters:

424

- disable: bool - Disable AG2 autologging

425

- log_traces: bool - Log conversation traces

426

- log_models: bool - Log agent models

427

- log_input_examples: bool - Log conversation examples

428

- log_model_signatures: bool - Log model signatures

429

- silent: bool - Suppress autolog warnings

430

"""

431

```

432

433

### Pydantic AI Integration

434

435

Pydantic AI framework integration for structured AI application development with automatic model and conversation logging (experimental in MLflow 3.0.0).

436

437

```python { .api }

438

import mlflow.pydantic_ai

439

440

def autolog(disable=False, log_traces=True, log_models=False, log_input_examples=False, log_model_signatures=True, silent=False):

441

"""

442

Enable automatic logging for Pydantic AI applications.

443

444

Parameters:

445

- disable: bool - Disable Pydantic AI autologging

446

- log_traces: bool - Log AI application traces

447

- log_models: bool - Log AI models

448

- log_input_examples: bool - Log input examples

449

- log_model_signatures: bool - Log model signatures

450

- silent: bool - Suppress autolog warnings

451

"""

452

```

453

454

### Smolagents Integration

455

456

Smolagents AI agents framework integration with conversation and task execution logging (experimental in MLflow 3.0.0).

457

458

```python { .api }

459

import mlflow.smolagents

460

461

def autolog(disable=False, log_traces=True, log_models=False, log_input_examples=False, log_model_signatures=True, silent=False):

462

"""

463

Enable automatic logging for Smolagents AI agents.

464

465

Parameters:

466

- disable: bool - Disable Smolagents autologging

467

- log_traces: bool - Log agent execution traces

468

- log_models: bool - Log agent models

469

- log_input_examples: bool - Log input examples

470

- log_model_signatures: bool - Log model signatures

471

- silent: bool - Suppress autolog warnings

472

"""

473

```

474

475

### Groq Integration

476

477

Groq API integration with automatic request/response logging and performance tracking.

478

479

```python { .api }

480

import mlflow.groq

481

482

def autolog(disable=False, log_traces=True, log_models=False, log_input_examples=False, log_model_signatures=True, silent=False):

483

"""

484

Enable automatic logging for Groq API calls.

485

486

Parameters:

487

- disable: bool - Disable Groq autologging

488

- log_traces: bool - Log API call traces

489

- log_models: bool - Log model configurations

490

- log_input_examples: bool - Log input examples

491

- log_model_signatures: bool - Log model signatures

492

- silent: bool - Suppress autolog warnings

493

"""

494

```

495

496

### Semantic Kernel Integration

497

498

Microsoft Semantic Kernel framework integration for orchestrating AI services with automatic logging and observability.

499

500

```python { .api }

501

import mlflow.semantic_kernel

502

503

def autolog(disable=False, log_traces=True, log_models=False, log_input_examples=False, log_model_signatures=True, silent=False):

504

"""

505

Enable automatic logging for Semantic Kernel applications.

506

507

Parameters:

508

- disable: bool - Disable Semantic Kernel autologging

509

- log_traces: bool - Log kernel execution traces

510

- log_models: bool - Log AI service configurations

511

- log_input_examples: bool - Log input examples

512

- log_model_signatures: bool - Log model signatures

513

- silent: bool - Suppress autolog warnings

514

"""

515

```

516

517

### Auto-logging Capabilities

518

519

Automatic experiment tracking across supported frameworks with minimal code changes.

520

521

```python { .api }

522

import mlflow

523

524

def autolog(log_input_examples=False, log_model_signatures=True, log_models=True, log_datasets=True, disable=False, exclusive=False, disable_for_unsupported_versions=False, silent=False, extra_tags=None, registered_model_name=None):

525

"""

526

Enable automatic logging across all supported frameworks.

527

528

Parameters:

529

- log_input_examples: bool - Log input examples for models

530

- log_model_signatures: bool - Log model input/output signatures

531

- log_models: bool - Log trained model objects

532

- log_datasets: bool - Log training/validation datasets

533

- disable: bool - Disable all autologging if True

534

- exclusive: bool - Use exclusive autologging mode

535

- disable_for_unsupported_versions: bool - Skip unsupported library versions

536

- silent: bool - Suppress autolog setup warnings

537

- extra_tags: dict, optional - Additional tags for all runs

538

- registered_model_name: str, optional - Auto-register models with name

539

"""

540

541

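# Framework-specific autolog entry points (shorthand names).
# The callable APIs live on each flavor module as mlflow.<flavor>.autolog(),
# e.g. mlflow.sklearn.autolog() or mlflow.xgboost.autolog().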

542

def sklearn_autolog(**kwargs):

543

"""Enable scikit-learn autologging."""

544

545

def pytorch_autolog(**kwargs):

546

"""Enable PyTorch autologging."""

547

548

def tensorflow_autolog(**kwargs):

549

"""Enable TensorFlow/Keras autologging."""

550

551

def xgboost_autolog(**kwargs):

552

"""Enable XGBoost autologging."""

553

554

def lightgbm_autolog(**kwargs):

555

"""Enable LightGBM autologging."""

556

557

def spark_autolog(**kwargs):

558

"""Enable Spark MLlib autologging."""

559

```

560

561

## Usage Examples

562

563

### Scikit-learn Model Logging

564

565

```python

566

import mlflow
import mlflow.sklearn
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Generate sample data (the split is seeded too, so logged metrics are reproducible)
X, y = make_classification(n_samples=1000, n_features=20, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Create the preprocessing + classifier pipeline
pipeline = Pipeline([
    ('scaler', StandardScaler()),
    ('classifier', RandomForestClassifier(n_estimators=100, random_state=42)),
])

mlflow.set_experiment("sklearn_integration")

with mlflow.start_run():
    # Train model
    pipeline.fit(X_train, y_train)

    # Log model with signature and example
    signature = mlflow.models.infer_signature(X_train, pipeline.predict(X_train))

    # Keep the returned ModelInfo: it carries the URI needed to load the model
    # after the run has finished.
    model_info = mlflow.sklearn.log_model(
        sk_model=pipeline,
        artifact_path="model",
        signature=signature,
        input_example=X_train[:3],
        registered_model_name="rf_pipeline",
    )

    # Log metrics
    train_score = pipeline.score(X_train, y_train)
    test_score = pipeline.score(X_test, y_test)

    mlflow.log_metric("train_accuracy", train_score)
    mlflow.log_metric("test_accuracy", test_score)

    print(f"Model logged with accuracy: {test_score:.3f}")

# Load and use model.
# mlflow.active_run() returns None once the `with` block has exited, so the
# URI is taken from the ModelInfo captured above instead.
loaded_model = mlflow.sklearn.load_model(model_info.model_uri)
predictions = loaded_model.predict(X_test)

614

```

615

616

### PyTorch Model with Custom Architecture

617

618

```python

619

import mlflow
import mlflow.pytorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


# Define custom model
class NeuralNet(nn.Module):
    """Two fully connected layers with a ReLU activation in between."""

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out


# Hyperparameters are defined once so the logged values always match the
# values actually used for training.
INPUT_SIZE = 20
HIDDEN_SIZE = 50
NUM_CLASSES = 2
LEARNING_RATE = 0.01
BATCH_SIZE = 32
NUM_EPOCHS = 10

# Prepare data
X = torch.randn(1000, INPUT_SIZE)
y = torch.randint(0, NUM_CLASSES, (1000,))
dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

mlflow.set_experiment("pytorch_integration")

with mlflow.start_run():
    # Initialize model
    model = NeuralNet(
        input_size=INPUT_SIZE, hidden_size=HIDDEN_SIZE, num_classes=NUM_CLASSES
    )
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Log hyperparameters
    mlflow.log_param("input_size", INPUT_SIZE)
    mlflow.log_param("hidden_size", HIDDEN_SIZE)
    mlflow.log_param("learning_rate", LEARNING_RATE)
    mlflow.log_param("batch_size", BATCH_SIZE)

    # Training loop: one averaged loss value is logged per epoch
    for epoch in range(NUM_EPOCHS):
        total_loss = 0
        for batch_x, batch_y in dataloader:
            optimizer.zero_grad()
            outputs = model(batch_x)
            loss = criterion(outputs, batch_y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        mlflow.log_metric("loss", avg_loss, step=epoch)

    # Log model; keep the returned ModelInfo so the model can be loaded after
    # the run has finished.
    model_info = mlflow.pytorch.log_model(
        pytorch_model=model,
        artifact_path="model",
        registered_model_name="neural_net",
    )

    # Log state dict separately
    mlflow.pytorch.log_state_dict(
        state_dict=model.state_dict(),
        artifact_path="state_dict",
    )

    print("PyTorch model logged successfully")

# Load model.
# mlflow.active_run() returns None once the `with` block has exited, so the
# URI is taken from the ModelInfo captured above instead.
loaded_model = mlflow.pytorch.load_model(model_info.model_uri)

692

```

693

694

### XGBoost with Autologging

695

696

```python

697

import mlflow
import mlflow.xgboost
import xgboost as xgb
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Enable XGBoost autologging
mlflow.xgboost.autolog(
    importance_type="gain",
    log_input_examples=True,
    log_model_signatures=True,
    registered_model_name="xgb_automodel",
)

# Prepare data (the split is seeded too, so logged metrics are reproducible)
X, y = make_classification(n_samples=1000, n_features=20, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

mlflow.set_experiment("xgboost_autolog")

with mlflow.start_run():
    # Train XGBoost model - automatically logged.
    # eval_metric is set on the estimator: passing it to fit() was deprecated
    # in XGBoost 1.6 and removed in 2.0.
    model = xgb.XGBClassifier(
        n_estimators=100,
        max_depth=6,
        learning_rate=0.1,
        random_state=42,
        eval_metric="logloss",
    )

    model.fit(
        X_train, y_train,
        eval_set=[(X_test, y_test)],
        verbose=False,
    )

    # Additional manual logging
    test_accuracy = model.score(X_test, y_test)
    mlflow.log_metric("test_accuracy", test_accuracy)

    print(f"XGBoost model auto-logged with accuracy: {test_accuracy:.3f}")

# Feature importance is automatically logged
# Model is automatically registered with specified name

741

```

742

743

### Transformers with Multiple Components

744

745

```python

746

import mlflow
import mlflow.transformers
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification

mlflow.set_experiment("transformers_integration")

with mlflow.start_run():
    # Load pre-trained model and tokenizer
    model_name = "distilbert-base-uncased-finetuned-sst-2-english"

    # Load components separately for more control
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)

    # Create pipeline.
    # top_k=None returns the score of every label; it replaces the deprecated
    # return_all_scores=True flag.
    sentiment_pipeline = pipeline(
        "sentiment-analysis",
        model=model,
        tokenizer=tokenizer,
        top_k=None,
    )

    # Log model with all components; keep the returned ModelInfo so the model
    # can be loaded after the run has finished.
    model_info = mlflow.transformers.log_model(
        transformers_model=sentiment_pipeline,
        artifact_path="sentiment_model",
        task="text-classification",
        tokenizer=tokenizer,
        model_config={
            "max_length": 512,
            "padding": True,
            "truncation": True,
        },
        registered_model_name="sentiment_classifier",
    )

    # Test the pipeline
    test_texts = [
        "I love this product!",
        "This is terrible.",
        "It's okay, nothing special.",
    ]

    results = sentiment_pipeline(test_texts)

    # Log example predictions.
    # Artifacts are named by position: hash(text) is salted per process and can
    # be negative, so it would give different file names on every execution.
    for idx, (text, result) in enumerate(zip(test_texts, results)):
        print(f"'{text}' -> {result}")
        mlflow.log_text(f"Prediction: {result}", f"example_{idx}.txt")

    print("Transformers model logged with tokenizer and config")

# Load and use model.
# mlflow.active_run() returns None once the `with` block has exited, so the
# URI is taken from the ModelInfo captured above instead.
loaded_pipeline = mlflow.transformers.load_model(model_info.model_uri)
new_predictions = loaded_pipeline(["MLflow is amazing!"])

802

```

803

804

### Spark MLlib Distributed Training

805

806

```python

807

import mlflow

808

import mlflow.spark

809

from pyspark.sql import SparkSession

810

from pyspark.ml.feature import VectorAssembler, StringIndexer

811

from pyspark.ml.classification import RandomForestClassifier

812

from pyspark.ml import Pipeline

813

from pyspark.ml.evaluation import MulticlassClassificationEvaluator

814

815

# Initialize Spark

816

spark = SparkSession.builder.appName("MLflow Spark Integration").getOrCreate()

817

818

# Enable Spark autologging

819

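import mlflow.pyspark.ml  # pyspark.ml estimator autologging; mlflow.spark.autolog() only accepts disable/silent
mlflow.pyspark.ml.autolog(log_models=True, log_input_examples=True)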

820

821

mlflow.set_experiment("spark_integration")

822

823

with mlflow.start_run():

824

# Create sample DataFrame

825

data = [(0.0, "a", 1.0, 0),

826

(1.0, "b", 2.0, 1),

827

(2.0, "c", 3.0, 0),

828

(3.0, "a", 4.0, 1)] * 100

829

830

columns = ["feature1", "category", "feature2", "label"]

831

df = spark.createDataFrame(data, columns)

832

833

# Create ML Pipeline

834

indexer = StringIndexer(inputCol="category", outputCol="categoryIndex")

835

assembler = VectorAssembler(

836

inputCols=["feature1", "categoryIndex", "feature2"],

837

outputCol="features"

838

)

839

rf = RandomForestClassifier(featuresCol="features", labelCol="label")

840

841

pipeline = Pipeline(stages=[indexer, assembler, rf])

842

843

# Split data

844

train_df, test_df = df.randomSplit([0.8, 0.2], seed=42)

845

846

# Train pipeline - automatically logged

847

model = pipeline.fit(train_df)

848

849

# Make predictions

850

predictions = model.transform(test_df)

851

852

# Evaluate model

853

evaluator = MulticlassClassificationEvaluator(

854

labelCol="label",

855

predictionCol="prediction",

856

metricName="accuracy"

857

)

858

accuracy = evaluator.evaluate(predictions)

859

860

mlflow.log_metric("test_accuracy", accuracy)

861

862

# Log model manually for more control

863

mlflow.spark.log_model(

864

spark_model=model,

865

artifact_path="spark_pipeline",

866

registered_model_name="spark_rf_pipeline"

867

)

868

869

print(f"Spark pipeline logged with accuracy: {accuracy:.3f}")

870

871

spark.stop()

872

```

873

874

### Multi-Framework Comparison

875

876

```python

877

import mlflow

878

import numpy as np

879

from sklearn.datasets import make_classification

880

from sklearn.model_selection import train_test_split

881

from sklearn.ensemble import RandomForestClassifier

882

import xgboost as xgb

883

import lightgbm as lgb

884

885

# Generate data

886

X, y = make_classification(n_samples=10000, n_features=20, random_state=42)

887

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

888

889

mlflow.set_experiment("framework_comparison")

890

891

# Compare multiple frameworks

892

frameworks = {

893

"sklearn": {

894

"model": RandomForestClassifier(n_estimators=100, random_state=42),

895

"log_func": mlflow.sklearn.log_model

896

},

897

"xgboost": {

898

"model": xgb.XGBClassifier(n_estimators=100, random_state=42),

899

"log_func": mlflow.xgboost.log_model

900

},

901

"lightgbm": {

902

"model": lgb.LGBMClassifier(n_estimators=100, random_state=42),

903

"log_func": mlflow.lightgbm.log_model

904

}

905

}

906

907

results = {}

908

909

for framework_name, config in frameworks.items():

910

with mlflow.start_run(run_name=f"{framework_name}_model"):

911

# Train model

912

model = config["model"]

913

model.fit(X_train, y_train)

914

915

# Evaluate

916

train_acc = model.score(X_train, y_train)

917

test_acc = model.score(X_test, y_test)

918

919

# Log metrics

920

mlflow.log_param("framework", framework_name)

921

mlflow.log_metric("train_accuracy", train_acc)

922

mlflow.log_metric("test_accuracy", test_acc)

923

924

# Log model

925

config["log_func"](

926

model,

927

artifact_path="model",

928

registered_model_name=f"{framework_name}_classifier"

929

)

930

931

results[framework_name] = {

932

"train_acc": train_acc,

933

"test_acc": test_acc,

934

"run_id": mlflow.active_run().info.run_id

935

}

936

937

print(f"{framework_name}: Train={train_acc:.3f}, Test={test_acc:.3f}")

938

939

# Find best model

940

best_framework = max(results.keys(), key=lambda k: results[k]["test_acc"])

941

print(f"\nBest framework: {best_framework} (Test Acc: {results[best_framework]['test_acc']:.3f})")

942

```

943

944

### Universal Autologging Setup

945

946

```python

947

import mlflow

948

import warnings

949

950

# Enable universal autologging

951

mlflow.autolog(

952

log_input_examples=True,

953

log_model_signatures=True,

954

log_models=True,

955

log_datasets=True,

956

extra_tags={"environment": "production", "team": "ml-platform"},

957

# registered_model_name is not accepted by mlflow.autolog(); pass it to a flavor-specific autolog (e.g. mlflow.sklearn.autolog) instead

958

)

959

960

# Suppress warnings for cleaner output

961

warnings.filterwarnings("ignore")

962

963

mlflow.set_experiment("universal_autolog")

964

965

# Now any supported ML training will be automatically logged

966

from sklearn.ensemble import GradientBoostingClassifier

967

import xgboost as xgb

968

from sklearn.datasets import make_classification

969

970

X, y = make_classification(n_samples=1000, n_features=10, random_state=42)

971

972

# Train multiple models - all automatically logged

973

models = [

974

("sklearn_gb", GradientBoostingClassifier(random_state=42)),

975

("xgboost", xgb.XGBClassifier(random_state=42))

976

]

977

978

for model_name, model in models:

979

with mlflow.start_run(run_name=f"auto_{model_name}"):

980

# Just train - everything else is automatic

981

model.fit(X, y)

982

983

# Only need to log custom metrics if desired

984

custom_score = model.score(X, y)

985

mlflow.log_metric("custom_accuracy", custom_score)

986

987

print(f"{model_name} automatically logged")

988

989

# Disable autologging when done

990

mlflow.autolog(disable=True)

991

```

992

993

## Types

994

995

```python { .api }

996

from typing import Any, Dict, List, Optional, Union

997

import torch

998

import tensorflow as tf

999

from sklearn.base import BaseEstimator

1000

import xgboost

1001

import lightgbm

1002

1003

# Common model types across frameworks

1004

SklearnModel = BaseEstimator

1005

PyTorchModel = torch.nn.Module

1006

TensorFlowModel = Union[tf.keras.Model, str] # Model or SavedModel path

1007

XGBoostModel = Union[xgboost.Booster, xgboost.XGBModel]

1008

LightGBMModel = Union[lightgbm.Booster, lightgbm.LGBMModel]

1009

1010

# Framework-specific logging function signatures

1011

def sklearn_log_model(

1012

sk_model: SklearnModel,

1013

artifact_path: str,

1014

**kwargs

1015

) -> 'ModelInfo': ...

1016

1017

def pytorch_log_model(

1018

pytorch_model: PyTorchModel,

1019

artifact_path: str,

1020

**kwargs

1021

) -> 'ModelInfo': ...

1022

1023

def tensorflow_log_model(

1024

tf_saved_model_dir: str,

1025

artifact_path: str,

1026

**kwargs

1027

) -> 'ModelInfo': ...

1028

1029

def xgboost_log_model(

1030

xgb_model: XGBoostModel,

1031

artifact_path: str,

1032

**kwargs

1033

) -> 'ModelInfo': ...

1034

1035

def lightgbm_log_model(

1036

lgb_model: LightGBMModel,

1037

artifact_path: str,

1038

**kwargs

1039

) -> 'ModelInfo': ...

1040

1041

# Loading function return types

1042

def sklearn_load_model(model_uri: str) -> SklearnModel: ...

1043

def pytorch_load_model(model_uri: str) -> PyTorchModel: ...

1044

def tensorflow_load_model(model_uri: str) -> TensorFlowModel: ...

1045

def xgboost_load_model(model_uri: str) -> XGBoostModel: ...

1046

def lightgbm_load_model(model_uri: str) -> LightGBMModel: ...

1047

1048

# Autolog configuration types

1049

AutologConfig = Dict[str, Union[bool, str, Dict[str, Any]]]

1050

1051

def autolog_function(

1052

log_input_examples: bool = False,

1053

log_model_signatures: bool = True,

1054

log_models: bool = True,

1055

disable: bool = False,

1056

exclusive: bool = False,

1057

disable_for_unsupported_versions: bool = False,

1058

silent: bool = False,

1059

registered_model_name: Optional[str] = None,

1060

**kwargs

1061

) -> None: ...

1062

1063

# Framework-specific types

1064

class TorchStateDict:

1065

"""PyTorch model state dictionary type."""

1066

pass

1067

1068

class SparkPipeline:

1069

"""Spark ML Pipeline type."""

1070

pass

1071

1072

class TransformersPipeline:

1073

"""Hugging Face Transformers Pipeline type."""

1074

pass

1075

1076

# Serialization format constants

1077

SERIALIZATION_FORMAT_PICKLE = "pickle"

1078

SERIALIZATION_FORMAT_CLOUDPICKLE = "cloudpickle"

1079

SERIALIZATION_FORMAT_JSON = "json"

1080

```