0
# ML Framework Integrations
1
2
MLflow provides comprehensive integrations with popular machine learning and deep learning frameworks, enabling seamless model logging, loading, and deployment across different ML ecosystems. Each integration offers framework-specific optimizations and native model format support.
3
4
## Capabilities
5
6
### Scikit-learn Integration
7
8
Native integration for scikit-learn models with automatic dependency management and preprocessing pipeline support.
9
10
```python { .api }
11
import mlflow.sklearn
12
13
def log_model(sk_model, artifact_path, conda_env=None, code_paths=None, registered_model_name=None, signature=None, input_example=None, await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS, pip_requirements=None, extra_pip_requirements=None, serialization_format=SERIALIZATION_FORMAT_CLOUDPICKLE, metadata=None, **kwargs):
14
"""
15
Log scikit-learn model as MLflow artifact.
16
17
Parameters:
18
- sk_model: Trained scikit-learn model object
19
- artifact_path: str - Run-relative artifact path
20
- conda_env: str or dict, optional - Conda environment specification
21
- code_paths: list, optional - List of local code paths to include
22
- registered_model_name: str, optional - Name for model registry
23
- signature: ModelSignature, optional - Model input/output schema
24
- input_example: Any, optional - Example input for inference
25
- await_registration_for: int - Seconds to wait for registration
26
- pip_requirements: list, optional - List of pip package requirements
27
- extra_pip_requirements: list, optional - Additional pip requirements
28
- serialization_format: str - Serialization format (pickle, cloudpickle)
29
- metadata: dict, optional - Custom model metadata
30
31
Returns:
32
ModelInfo object with logged model details
33
"""
34
35
def load_model(model_uri, dst_path=None):
36
"""
37
Load scikit-learn model from MLflow.
38
39
Parameters:
40
- model_uri: str - URI pointing to MLflow model
41
- dst_path: str, optional - Local destination path
42
43
Returns:
44
Loaded scikit-learn model object
45
"""
46
47
def save_model(sk_model, path, conda_env=None, code_paths=None, mlflow_model=None, signature=None, input_example=None, pip_requirements=None, extra_pip_requirements=None, serialization_format=SERIALIZATION_FORMAT_CLOUDPICKLE, metadata=None):
48
"""
49
Save scikit-learn model to local path.
50
51
Parameters:
52
- sk_model: Trained scikit-learn model object
53
- path: str - Local path to save model
54
- conda_env: str or dict, optional - Conda environment
55
- code_paths: list, optional - Code dependencies to include
56
- mlflow_model: Model, optional - MLflow model configuration
57
- signature: ModelSignature, optional - Model signature
58
- input_example: Any, optional - Example input
59
- pip_requirements: list, optional - Pip package requirements
60
- extra_pip_requirements: list, optional - Additional pip requirements
61
- serialization_format: str - Serialization format
62
- metadata: dict, optional - Custom metadata
63
"""
64
```
65
66
### PyTorch Integration
67
68
Comprehensive PyTorch support including standard models, PyTorch Lightning, and TorchScript compilation.
69
70
```python { .api }
71
import mlflow.pytorch
72
73
def log_model(pytorch_model, artifact_path, conda_env=None, code_paths=None, pickle_module=None, registered_model_name=None, signature=None, input_example=None, await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS, requirements_file=None, extra_files=None, pip_requirements=None, extra_pip_requirements=None, metadata=None, **kwargs):
74
"""
75
Log PyTorch model as MLflow artifact.
76
77
Parameters:
78
- pytorch_model: torch.nn.Module - PyTorch model to log (TorchScript modules are also accepted; to log only a state_dict, use log_state_dict)
79
- artifact_path: str - Run-relative artifact path
80
- conda_env: str or dict, optional - Conda environment
81
- code_paths: list, optional - Local code paths to include
82
- pickle_module: module, optional - Module for model serialization
83
- registered_model_name: str, optional - Registry model name
84
- signature: ModelSignature, optional - Model schema
85
- input_example: Any, optional - Example model input
86
- await_registration_for: int - Registration wait time
87
- requirements_file: str, optional - Path to requirements file
88
- extra_files: list, optional - Additional files to include
89
- pip_requirements: list, optional - Pip requirements
90
- extra_pip_requirements: list, optional - Additional pip requirements
91
- metadata: dict, optional - Custom metadata
92
93
Returns:
94
ModelInfo object
95
"""
96
97
def load_model(model_uri, map_location=None, dst_path=None):
98
"""
99
Load PyTorch model from MLflow.
100
101
Parameters:
102
- model_uri: str - URI pointing to MLflow model
103
- map_location: str or torch.device, optional - Device mapping for loading
104
- dst_path: str, optional - Local destination path
105
106
Returns:
107
Loaded PyTorch model object
108
"""
109
110
def log_state_dict(state_dict, artifact_path, **kwargs):
111
"""
112
Log PyTorch model state dictionary.
113
114
Parameters:
115
- state_dict: dict - PyTorch model state dictionary
116
- artifact_path: str - Artifact path for state dict
117
- kwargs: Additional logging arguments
118
"""
119
120
def load_state_dict(model_uri, map_location=None):
121
"""
122
Load PyTorch state dictionary from MLflow.
123
124
Parameters:
125
- model_uri: str - URI pointing to saved state dict
126
- map_location: str or device, optional - Device for loading
127
128
Returns:
129
PyTorch state dictionary
130
"""
131
```
132
133
### TensorFlow Integration
134
135
Full TensorFlow support including Keras models, SavedModel format, and TensorFlow Serving compatibility.
136
137
```python { .api }
138
import mlflow.tensorflow
139
140
def log_model(tf_saved_model_dir=None, tf_meta_graph_tags=None, tf_signature_def_key=None, artifact_path=None, conda_env=None, code_paths=None, registered_model_name=None, signature=None, input_example=None, await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS, pip_requirements=None, extra_pip_requirements=None, metadata=None, **kwargs):
141
"""
142
Log TensorFlow model as MLflow artifact.
143
144
Parameters:
145
- tf_saved_model_dir: str - Path to TensorFlow SavedModel directory
146
- tf_meta_graph_tags: list, optional - MetaGraph tags to load
147
- tf_signature_def_key: str, optional - SignatureDef key for inference
148
- artifact_path: str - Run-relative artifact path
149
- conda_env: str or dict, optional - Conda environment
150
- code_paths: list, optional - Code dependencies
151
- registered_model_name: str, optional - Registry model name
152
- signature: ModelSignature, optional - Model schema
153
- input_example: Any, optional - Example input
154
- await_registration_for: int - Registration wait time
155
- pip_requirements: list, optional - Pip requirements
156
- extra_pip_requirements: list, optional - Additional pip requirements
157
- metadata: dict, optional - Custom metadata
158
159
Returns:
160
ModelInfo object
161
"""
162
163
def load_model(model_uri, dst_path=None):
164
"""
165
Load TensorFlow model from MLflow.
166
167
Parameters:
168
- model_uri: str - URI pointing to MLflow model
169
- dst_path: str, optional - Local destination path
170
171
Returns:
172
Loaded TensorFlow model object
173
"""
174
175
import mlflow.keras
176
177
def log_model(keras_model, artifact_path, conda_env=None, code_paths=None, custom_objects=None, keras_module=None, registered_model_name=None, signature=None, input_example=None, await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS, pip_requirements=None, extra_pip_requirements=None, metadata=None, **kwargs):
178
"""
179
Log Keras model as MLflow artifact.
180
181
Parameters:
182
- keras_model: Compiled Keras model object
183
- artifact_path: str - Run-relative artifact path
184
- conda_env: str or dict, optional - Conda environment
185
- code_paths: list, optional - Code dependencies
186
- custom_objects: dict, optional - Custom objects for model loading
187
- keras_module: module, optional - Keras module for compatibility
188
- registered_model_name: str, optional - Registry model name
189
- signature: ModelSignature, optional - Model schema
190
- input_example: Any, optional - Example input
191
- await_registration_for: int - Registration wait time
192
- pip_requirements: list, optional - Pip requirements
193
- extra_pip_requirements: list, optional - Additional pip requirements
194
- metadata: dict, optional - Custom metadata
195
196
Returns:
197
ModelInfo object
198
"""
199
```
200
201
### XGBoost Integration
202
203
Native XGBoost model support with automatic hyperparameter tracking and feature importance logging.
204
205
```python { .api }
206
import mlflow.xgboost
207
208
def log_model(xgb_model, artifact_path, conda_env=None, code_paths=None, registered_model_name=None, signature=None, input_example=None, await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS, pip_requirements=None, extra_pip_requirements=None, model_format="xgb", metadata=None, **kwargs):
209
"""
210
Log XGBoost model as MLflow artifact.
211
212
Parameters:
213
- xgb_model: Trained XGBoost model (Booster, XGBClassifier, XGBRegressor)
214
- artifact_path: str - Run-relative artifact path
215
- conda_env: str or dict, optional - Conda environment
216
- code_paths: list, optional - Code dependencies
217
- registered_model_name: str, optional - Registry model name
218
- signature: ModelSignature, optional - Model schema
219
- input_example: Any, optional - Example input
220
- await_registration_for: int - Registration wait time
221
- pip_requirements: list, optional - Pip requirements
222
- extra_pip_requirements: list, optional - Additional requirements
223
- model_format: str - Save format ("xgb", "json", "ubj")
224
- metadata: dict, optional - Custom metadata
225
226
Returns:
227
ModelInfo object
228
"""
229
230
def load_model(model_uri, dst_path=None):
231
"""
232
Load XGBoost model from MLflow.
233
234
Parameters:
235
- model_uri: str - URI pointing to MLflow model
236
- dst_path: str, optional - Local destination path
237
238
Returns:
239
Loaded XGBoost model object
240
"""
241
242
def autolog(importance_type="weight", log_input_examples=False, log_model_signatures=True, log_models=True, disable=False, exclusive=False, disable_for_unsupported_versions=False, silent=False, registered_model_name=None):
243
"""
244
Enable automatic logging for XGBoost training.
245
246
Parameters:
247
- importance_type: str - Feature importance type to log
248
- log_input_examples: bool - Whether to log input examples
249
- log_model_signatures: bool - Whether to log model signatures
250
- log_models: bool - Whether to log trained models
251
- disable: bool - Disable autologging if True
252
- exclusive: bool - Exclusive autologging mode
253
- disable_for_unsupported_versions: bool - Skip for unsupported versions
254
- silent: bool - Suppress autolog warnings
255
- registered_model_name: str, optional - Auto-register model name
256
"""
257
```
258
259
### LightGBM Integration
260
261
Comprehensive LightGBM support with early stopping integration and automatic metric logging.
262
263
```python { .api }
264
import mlflow.lightgbm
265
266
def log_model(lgb_model, artifact_path, conda_env=None, code_paths=None, registered_model_name=None, signature=None, input_example=None, await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS, pip_requirements=None, extra_pip_requirements=None, metadata=None, **kwargs):
267
"""
268
Log LightGBM model as MLflow artifact.
269
270
Parameters:
271
- lgb_model: Trained LightGBM model (Booster, LGBMClassifier, LGBMRegressor)
272
- artifact_path: str - Run-relative artifact path
273
- conda_env: str or dict, optional - Conda environment
274
- code_paths: list, optional - Code dependencies
275
- registered_model_name: str, optional - Registry model name
276
- signature: ModelSignature, optional - Model schema
277
- input_example: Any, optional - Example input
278
- await_registration_for: int - Registration wait time
279
- pip_requirements: list, optional - Pip requirements
280
- extra_pip_requirements: list, optional - Additional requirements
281
- metadata: dict, optional - Custom metadata
282
283
Returns:
284
ModelInfo object
285
"""
286
287
def autolog(importance_type="split", log_input_examples=False, log_model_signatures=True, log_models=True, disable=False, exclusive=False, disable_for_unsupported_versions=False, silent=False, registered_model_name=None):
288
"""
289
Enable automatic logging for LightGBM training.
290
291
Parameters:
292
- importance_type: str - Feature importance type ("split", "gain")
293
- log_input_examples: bool - Log input examples
294
- log_model_signatures: bool - Log model signatures
295
- log_models: bool - Log trained models
296
- disable: bool - Disable autologging
297
- exclusive: bool - Exclusive autologging mode
298
- disable_for_unsupported_versions: bool - Skip unsupported versions
299
- silent: bool - Suppress warnings
300
- registered_model_name: str, optional - Auto-register model name
301
"""
302
```
303
304
### Transformers Integration
305
306
Hugging Face Transformers integration with support for various model types and tokenizers.
307
308
```python { .api }
309
import mlflow.transformers
310
311
def log_model(transformers_model, artifact_path, task=None, conda_env=None, code_paths=None, registered_model_name=None, signature=None, input_example=None, await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS, pip_requirements=None, extra_pip_requirements=None, metadata=None, tokenizer=None, feature_extractor=None, processor=None, model_config=None, **kwargs):
312
"""
313
Log Transformers model as MLflow artifact.
314
315
Parameters:
316
- transformers_model: Transformers model or pipeline object
317
- artifact_path: str - Run-relative artifact path
318
- task: str, optional - Task type for the model
319
- conda_env: str or dict, optional - Conda environment
320
- code_paths: list, optional - Code dependencies
321
- registered_model_name: str, optional - Registry model name
322
- signature: ModelSignature, optional - Model schema
323
- input_example: Any, optional - Example input
324
- await_registration_for: int - Registration wait time
325
- pip_requirements: list, optional - Pip requirements
326
- extra_pip_requirements: list, optional - Additional requirements
327
- metadata: dict, optional - Custom metadata
328
- tokenizer: Tokenizer, optional - Associated tokenizer
329
- feature_extractor: FeatureExtractor, optional - Feature extractor
330
- processor: Processor, optional - Processor object
331
- model_config: dict, optional - Model configuration
332
333
Returns:
334
ModelInfo object
335
"""
336
337
def load_model(model_uri, dst_path=None, device=None):
338
"""
339
Load Transformers model from MLflow.
340
341
Parameters:
342
- model_uri: str - URI pointing to MLflow model
343
- dst_path: str, optional - Local destination path
344
- device: str or int, optional - Device for model loading
345
346
Returns:
347
Loaded Transformers model or pipeline
348
"""
349
```
350
351
### Spark MLlib Integration
352
353
Apache Spark MLlib integration for distributed machine learning model logging and serving.
354
355
```python { .api }
356
import mlflow.spark
357
358
def log_model(spark_model, artifact_path, conda_env=None, code_paths=None, dfs_tmpdir=None, sample_input=None, registered_model_name=None, signature=None, input_example=None, await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS, pip_requirements=None, extra_pip_requirements=None, metadata=None, **kwargs):
359
"""
360
Log Spark MLlib model as MLflow artifact.
361
362
Parameters:
363
- spark_model: Fitted Spark MLlib model or pipeline
364
- artifact_path: str - Run-relative artifact path
365
- conda_env: str or dict, optional - Conda environment
366
- code_paths: list, optional - Code dependencies
367
- dfs_tmpdir: str, optional - Temporary directory for DFS operations
368
- sample_input: DataFrame, optional - Sample input for schema inference
369
- registered_model_name: str, optional - Registry model name
370
- signature: ModelSignature, optional - Model schema
371
- input_example: Any, optional - Example input
372
- await_registration_for: int - Registration wait time
373
- pip_requirements: list, optional - Pip requirements
374
- extra_pip_requirements: list, optional - Additional requirements
375
- metadata: dict, optional - Custom metadata
376
377
Returns:
378
ModelInfo object
379
"""
380
381
def load_model(model_uri, dfs_tmpdir=None):
382
"""
383
Load Spark MLlib model from MLflow.
384
385
Parameters:
386
- model_uri: str - URI pointing to MLflow model
387
- dfs_tmpdir: str, optional - Temporary directory for DFS
388
389
Returns:
390
Loaded Spark MLlib model or pipeline
391
"""
392
393
import mlflow.pyspark.ml
394
395
def autolog(disable=False, exclusive=False, disable_for_unsupported_versions=False, silent=False, log_models=True, log_input_examples=False, log_model_signatures=True, log_post_training_metrics=True, registered_model_name=None):
396
"""
397
Enable automatic logging for PySpark ML training.
398
399
Parameters:
400
- disable: bool - Disable autologging
401
- exclusive: bool - Exclusive autologging mode
402
- disable_for_unsupported_versions: bool - Skip unsupported versions
403
- silent: bool - Suppress warnings
404
- log_models: bool - Log trained models
405
- log_input_examples: bool - Log input examples
406
- log_model_signatures: bool - Log model signatures
407
- log_post_training_metrics: bool - Log evaluation metrics
408
- registered_model_name: str, optional - Auto-register model name
409
"""
410
```
411
412
### AG2 (AutoGen) Integration
413
414
Multi-agent conversation framework integration with automatic conversation logging and observability (experimental in MLflow 3.0.0).
415
416
```python { .api }
417
import mlflow.ag2
418
419
def autolog(disable=False, log_traces=True, log_models=False, log_input_examples=False, log_model_signatures=True, silent=False):
420
"""
421
Enable automatic logging for AG2 (AutoGen) conversations.
422
423
Parameters:
424
- disable: bool - Disable AG2 autologging
425
- log_traces: bool - Log conversation traces
426
- log_models: bool - Log agent models
427
- log_input_examples: bool - Log conversation examples
428
- log_model_signatures: bool - Log model signatures
429
- silent: bool - Suppress autolog warnings
430
"""
431
```
432
433
### Pydantic AI Integration
434
435
Pydantic AI framework integration for structured AI application development with automatic model and conversation logging (experimental in MLflow 3.0.0).
436
437
```python { .api }
438
import mlflow.pydantic_ai
439
440
def autolog(disable=False, log_traces=True, log_models=False, log_input_examples=False, log_model_signatures=True, silent=False):
441
"""
442
Enable automatic logging for Pydantic AI applications.
443
444
Parameters:
445
- disable: bool - Disable Pydantic AI autologging
446
- log_traces: bool - Log AI application traces
447
- log_models: bool - Log AI models
448
- log_input_examples: bool - Log input examples
449
- log_model_signatures: bool - Log model signatures
450
- silent: bool - Suppress autolog warnings
451
"""
452
```
453
454
### Smolagents Integration
455
456
Smolagents AI agents framework integration with conversation and task execution logging (experimental in MLflow 3.0.0).
457
458
```python { .api }
459
import mlflow.smolagents
460
461
def autolog(disable=False, log_traces=True, log_models=False, log_input_examples=False, log_model_signatures=True, silent=False):
462
"""
463
Enable automatic logging for Smolagents AI agents.
464
465
Parameters:
466
- disable: bool - Disable Smolagents autologging
467
- log_traces: bool - Log agent execution traces
468
- log_models: bool - Log agent models
469
- log_input_examples: bool - Log input examples
470
- log_model_signatures: bool - Log model signatures
471
- silent: bool - Suppress autolog warnings
472
"""
473
```
474
475
### Groq Integration
476
477
Groq API integration with automatic request/response logging and performance tracking.
478
479
```python { .api }
480
import mlflow.groq
481
482
def autolog(disable=False, log_traces=True, log_models=False, log_input_examples=False, log_model_signatures=True, silent=False):
483
"""
484
Enable automatic logging for Groq API calls.
485
486
Parameters:
487
- disable: bool - Disable Groq autologging
488
- log_traces: bool - Log API call traces
489
- log_models: bool - Log model configurations
490
- log_input_examples: bool - Log input examples
491
- log_model_signatures: bool - Log model signatures
492
- silent: bool - Suppress autolog warnings
493
"""
494
```
495
496
### Semantic Kernel Integration
497
498
Microsoft Semantic Kernel framework integration for orchestrating AI services with automatic logging and observability.
499
500
```python { .api }
501
import mlflow.semantic_kernel
502
503
def autolog(disable=False, log_traces=True, log_models=False, log_input_examples=False, log_model_signatures=True, silent=False):
504
"""
505
Enable automatic logging for Semantic Kernel applications.
506
507
Parameters:
508
- disable: bool - Disable Semantic Kernel autologging
509
- log_traces: bool - Log kernel execution traces
510
- log_models: bool - Log AI service configurations
511
- log_input_examples: bool - Log input examples
512
- log_model_signatures: bool - Log model signatures
513
- silent: bool - Suppress autolog warnings
514
"""
515
```
516
517
### Auto-logging Capabilities
518
519
Automatic experiment tracking across supported frameworks with minimal code changes.
520
521
```python { .api }
522
import mlflow
523
524
def autolog(log_input_examples=False, log_model_signatures=True, log_models=True, log_datasets=True, disable=False, exclusive=False, disable_for_unsupported_versions=False, silent=False, extra_tags=None, registered_model_name=None):
525
"""
526
Enable automatic logging across all supported frameworks.
527
528
Parameters:
529
- log_input_examples: bool - Log input examples for models
530
- log_model_signatures: bool - Log model input/output signatures
531
- log_models: bool - Log trained model objects
532
- log_datasets: bool - Log training/validation datasets
533
- disable: bool - Disable all autologging if True
534
- exclusive: bool - Use exclusive autologging mode
535
- disable_for_unsupported_versions: bool - Skip unsupported library versions
536
- silent: bool - Suppress autolog setup warnings
537
- extra_tags: dict, optional - Additional tags for all runs
538
- registered_model_name: str, optional - Auto-register models with name
539
"""
540
541
# Framework-specific autolog functions
542
def sklearn_autolog(**kwargs):
543
"""Enable scikit-learn autologging."""
544
545
def pytorch_autolog(**kwargs):
546
"""Enable PyTorch autologging."""
547
548
def tensorflow_autolog(**kwargs):
549
"""Enable TensorFlow/Keras autologging."""
550
551
def xgboost_autolog(**kwargs):
552
"""Enable XGBoost autologging."""
553
554
def lightgbm_autolog(**kwargs):
555
"""Enable LightGBM autologging."""
556
557
def spark_autolog(**kwargs):
558
"""Enable Spark MLlib autologging."""
559
```
560
561
## Usage Examples
562
563
### Scikit-learn Model Logging
564
565
```python
566
import mlflow
567
import mlflow.sklearn
568
from sklearn.ensemble import RandomForestClassifier
569
from sklearn.datasets import make_classification
570
from sklearn.model_selection import train_test_split
571
from sklearn.pipeline import Pipeline
572
from sklearn.preprocessing import StandardScaler
573
574
# Generate sample data
575
X, y = make_classification(n_samples=1000, n_features=20, random_state=42)
576
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
577
578
# Create and train pipeline
579
pipeline = Pipeline([
580
('scaler', StandardScaler()),
581
('classifier', RandomForestClassifier(n_estimators=100, random_state=42))
582
])
583
584
mlflow.set_experiment("sklearn_integration")
585
586
with mlflow.start_run():
587
# Train model
588
pipeline.fit(X_train, y_train)
589
590
# Log model with signature and example
591
signature = mlflow.models.infer_signature(X_train, pipeline.predict(X_train))
592
593
mlflow.sklearn.log_model(
594
sk_model=pipeline,
595
artifact_path="model",
596
signature=signature,
597
input_example=X_train[:3],
598
registered_model_name="rf_pipeline"
599
)
600
601
# Log metrics
602
train_score = pipeline.score(X_train, y_train)
603
test_score = pipeline.score(X_test, y_test)
604
605
mlflow.log_metric("train_accuracy", train_score)
606
mlflow.log_metric("test_accuracy", test_score)
607
608
print(f"Model logged with accuracy: {test_score:.3f}")
609
610
# Load and use model
611
model_uri = f"runs:/{mlflow.last_active_run().info.run_id}/model"
612
loaded_model = mlflow.sklearn.load_model(model_uri)
613
predictions = loaded_model.predict(X_test)
614
```
615
616
### PyTorch Model with Custom Architecture
617
618
```python
619
import mlflow
620
import mlflow.pytorch
621
import torch
622
import torch.nn as nn
623
import torch.optim as optim
624
from torch.utils.data import DataLoader, TensorDataset
625
626
# Define custom model
627
class NeuralNet(nn.Module):
628
def __init__(self, input_size, hidden_size, num_classes):
629
super(NeuralNet, self).__init__()
630
self.fc1 = nn.Linear(input_size, hidden_size)
631
self.relu = nn.ReLU()
632
self.fc2 = nn.Linear(hidden_size, num_classes)
633
634
def forward(self, x):
635
out = self.fc1(x)
636
out = self.relu(out)
637
out = self.fc2(out)
638
return out
639
640
# Prepare data
641
X = torch.randn(1000, 20)
642
y = torch.randint(0, 2, (1000,))
643
dataset = TensorDataset(X, y)
644
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
645
646
mlflow.set_experiment("pytorch_integration")
647
648
with mlflow.start_run():
649
# Initialize model
650
model = NeuralNet(input_size=20, hidden_size=50, num_classes=2)
651
criterion = nn.CrossEntropyLoss()
652
optimizer = optim.Adam(model.parameters(), lr=0.01)
653
654
# Log hyperparameters
655
mlflow.log_param("input_size", 20)
656
mlflow.log_param("hidden_size", 50)
657
mlflow.log_param("learning_rate", 0.01)
658
mlflow.log_param("batch_size", 32)
659
660
# Training loop
661
for epoch in range(10):
662
total_loss = 0
663
for batch_x, batch_y in dataloader:
664
optimizer.zero_grad()
665
outputs = model(batch_x)
666
loss = criterion(outputs, batch_y)
667
loss.backward()
668
optimizer.step()
669
total_loss += loss.item()
670
671
avg_loss = total_loss / len(dataloader)
672
mlflow.log_metric("loss", avg_loss, step=epoch)
673
674
# Log model
675
mlflow.pytorch.log_model(
676
pytorch_model=model,
677
artifact_path="model",
678
registered_model_name="neural_net"
679
)
680
681
# Log state dict separately
682
mlflow.pytorch.log_state_dict(
683
state_dict=model.state_dict(),
684
artifact_path="state_dict"
685
)
686
687
print("PyTorch model logged successfully")
688
689
# Load model
690
model_uri = f"runs:/{mlflow.last_active_run().info.run_id}/model"
691
loaded_model = mlflow.pytorch.load_model(model_uri)
692
```
693
694
### XGBoost with Autologging
695
696
```python
697
import mlflow
698
import mlflow.xgboost
699
import xgboost as xgb
700
from sklearn.datasets import make_classification
701
from sklearn.model_selection import train_test_split
702
703
# Enable XGBoost autologging
704
mlflow.xgboost.autolog(
705
importance_type="gain",
706
log_input_examples=True,
707
log_model_signatures=True,
708
registered_model_name="xgb_automodel"
709
)
710
711
# Prepare data
712
X, y = make_classification(n_samples=1000, n_features=20, random_state=42)
713
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
714
715
mlflow.set_experiment("xgboost_autolog")
716
717
with mlflow.start_run():
718
# Train XGBoost model - automatically logged
719
model = xgb.XGBClassifier(
720
n_estimators=100,
721
max_depth=6,
722
learning_rate=0.1,
723
random_state=42,
724
eval_metric="logloss"
)
725
726
model.fit(
727
X_train, y_train,
728
eval_set=[(X_test, y_test)],
729
730
verbose=False
731
)
732
733
# Additional manual logging
734
test_accuracy = model.score(X_test, y_test)
735
mlflow.log_metric("test_accuracy", test_accuracy)
736
737
print(f"XGBoost model auto-logged with accuracy: {test_accuracy:.3f}")
738
739
# Feature importance is automatically logged
740
# Model is automatically registered with specified name
741
```
742
743
### Transformers with Multiple Components
744
745
```python
746
import mlflow
747
import mlflow.transformers
748
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
749
750
mlflow.set_experiment("transformers_integration")
751
752
with mlflow.start_run():
753
# Load pre-trained model and tokenizer
754
model_name = "distilbert-base-uncased-finetuned-sst-2-english"
755
756
# Load components separately for more control
757
tokenizer = AutoTokenizer.from_pretrained(model_name)
758
model = AutoModelForSequenceClassification.from_pretrained(model_name)
759
760
# Create pipeline
761
sentiment_pipeline = pipeline(
762
"sentiment-analysis",
763
model=model,
764
tokenizer=tokenizer,
765
return_all_scores=True
766
)
767
768
# Log model with all components
769
mlflow.transformers.log_model(
770
transformers_model=sentiment_pipeline,
771
artifact_path="sentiment_model",
772
task="text-classification",
773
tokenizer=tokenizer,
774
model_config={
775
"max_length": 512,
776
"padding": True,
777
"truncation": True
778
},
779
registered_model_name="sentiment_classifier"
780
)
781
782
# Test the pipeline
783
test_texts = [
784
"I love this product!",
785
"This is terrible.",
786
"It's okay, nothing special."
787
]
788
789
results = sentiment_pipeline(test_texts)
790
791
# Log example predictions
792
for i, (text, result) in enumerate(zip(test_texts, results)):
793
print(f"'{text}' -> {result}")
794
mlflow.log_text(f"Prediction: {result}", f"example_{i}.txt")
795
796
print("Transformers model logged with tokenizer and config")
797
798
# Load and use model
799
model_uri = f"runs:/{mlflow.last_active_run().info.run_id}/sentiment_model"
800
loaded_pipeline = mlflow.transformers.load_model(model_uri)
801
new_predictions = loaded_pipeline(["MLflow is amazing!"])
802
```
803
804
### Spark MLlib Distributed Training
805
806
```python
807
import mlflow
808
import mlflow.spark
import mlflow.pyspark.ml
809
from pyspark.sql import SparkSession
810
from pyspark.ml.feature import VectorAssembler, StringIndexer
811
from pyspark.ml.classification import RandomForestClassifier
812
from pyspark.ml import Pipeline
813
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
814
815
# Initialize Spark
816
spark = SparkSession.builder.appName("MLflow Spark Integration").getOrCreate()
817
818
# Enable Spark ML autologging (mlflow.spark.autolog only tracks datasource reads)
819
mlflow.pyspark.ml.autolog(log_models=True, log_input_examples=True)
820
821
mlflow.set_experiment("spark_integration")
822
823
with mlflow.start_run():
824
# Create sample DataFrame
825
data = [(0.0, "a", 1.0, 0),
826
(1.0, "b", 2.0, 1),
827
(2.0, "c", 3.0, 0),
828
(3.0, "a", 4.0, 1)] * 100
829
830
columns = ["feature1", "category", "feature2", "label"]
831
df = spark.createDataFrame(data, columns)
832
833
# Create ML Pipeline
834
indexer = StringIndexer(inputCol="category", outputCol="categoryIndex")
835
assembler = VectorAssembler(
836
inputCols=["feature1", "categoryIndex", "feature2"],
837
outputCol="features"
838
)
839
rf = RandomForestClassifier(featuresCol="features", labelCol="label")
840
841
pipeline = Pipeline(stages=[indexer, assembler, rf])
842
843
# Split data
844
train_df, test_df = df.randomSplit([0.8, 0.2], seed=42)
845
846
# Train pipeline - automatically logged
847
model = pipeline.fit(train_df)
848
849
# Make predictions
850
predictions = model.transform(test_df)
851
852
# Evaluate model
853
evaluator = MulticlassClassificationEvaluator(
854
labelCol="label",
855
predictionCol="prediction",
856
metricName="accuracy"
857
)
858
accuracy = evaluator.evaluate(predictions)
859
860
mlflow.log_metric("test_accuracy", accuracy)
861
862
# Log model manually for more control
863
mlflow.spark.log_model(
864
spark_model=model,
865
artifact_path="spark_pipeline",
866
registered_model_name="spark_rf_pipeline"
867
)
868
869
print(f"Spark pipeline logged with accuracy: {accuracy:.3f}")
870
871
spark.stop()
872
```
873
874
### Multi-Framework Comparison
875
876
```python
877
import mlflow
878
import numpy as np
879
from sklearn.datasets import make_classification
880
from sklearn.model_selection import train_test_split
881
from sklearn.ensemble import RandomForestClassifier
882
import xgboost as xgb
883
import lightgbm as lgb
884
885
# Generate a synthetic classification dataset. The split is seeded like the
# dataset and the models, so every execution compares the frameworks on the
# same train/test partition.
X, y = make_classification(n_samples=10000, n_features=20, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

mlflow.set_experiment("framework_comparison")

# Compare multiple frameworks: each entry pairs an unfitted estimator with the
# MLflow flavor function used to log it.
frameworks = {
    "sklearn": {
        "model": RandomForestClassifier(n_estimators=100, random_state=42),
        "log_func": mlflow.sklearn.log_model,
    },
    "xgboost": {
        "model": xgb.XGBClassifier(n_estimators=100, random_state=42),
        "log_func": mlflow.xgboost.log_model,
    },
    "lightgbm": {
        "model": lgb.LGBMClassifier(n_estimators=100, random_state=42),
        "log_func": mlflow.lightgbm.log_model,
    },
}

results = {}

for framework_name, config in frameworks.items():
    # One run per framework keeps its params, metrics and model artifact together.
    with mlflow.start_run(run_name=f"{framework_name}_model") as run:
        # Train model
        model = config["model"]
        model.fit(X_train, y_train)

        # Evaluate
        train_acc = model.score(X_train, y_train)
        test_acc = model.score(X_test, y_test)

        # Log metrics
        mlflow.log_param("framework", framework_name)
        mlflow.log_metric("train_accuracy", train_acc)
        mlflow.log_metric("test_accuracy", test_acc)

        # Log model. The model is passed positionally because each flavor names
        # its first parameter differently (sk_model, xgb_model, lgb_model).
        config["log_func"](
            model,
            artifact_path="model",
            registered_model_name=f"{framework_name}_classifier",
        )

        results[framework_name] = {
            "train_acc": train_acc,
            "test_acc": test_acc,
            "run_id": run.info.run_id,
        }

        print(f"{framework_name}: Train={train_acc:.3f}, Test={test_acc:.3f}")

# Find best model: the framework with the highest held-out accuracy.
best_framework = max(results, key=lambda name: results[name]["test_acc"])
print(f"\nBest framework: {best_framework} (Test Acc: {results[best_framework]['test_acc']:.3f})")
942
```
943
944
### Universal Autologging Setup
945
946
```python
947
import mlflow
948
import warnings
949
950
# Enable universal autologging.
# Model registration is not an option of mlflow.autolog(); to register
# autologged models, configure it on a flavor-level autolog call (for example
# mlflow.sklearn.autolog(registered_model_name=...)) or register the logged
# model explicitly after training.
mlflow.autolog(
    log_input_examples=True,
    log_model_signatures=True,
    log_models=True,
    log_datasets=True,
    extra_tags={"environment": "production", "team": "ml-platform"},
)

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")

mlflow.set_experiment("universal_autolog")

# Now any supported ML training will be automatically logged
from sklearn.ensemble import GradientBoostingClassifier
import xgboost as xgb
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=1000, n_features=10, random_state=42)

# Train multiple models - all automatically logged
models = [
    ("sklearn_gb", GradientBoostingClassifier(random_state=42)),
    ("xgboost", xgb.XGBClassifier(random_state=42)),
]

for model_name, model in models:
    with mlflow.start_run(run_name=f"auto_{model_name}"):
        # Just train - everything else is automatic
        model.fit(X, y)

        # Only need to log custom metrics if desired
        custom_score = model.score(X, y)
        mlflow.log_metric("custom_accuracy", custom_score)

        print(f"{model_name} automatically logged")

# Disable autologging when done
mlflow.autolog(disable=True)
991
```
992
993
## Types
994
995
```python { .api }
996
from typing import Any, Dict, List, Optional, Union
997
import torch
998
import tensorflow as tf
999
from sklearn.base import BaseEstimator
1000
import xgboost
1001
import lightgbm
1002
1003
# Common model types across frameworks: one alias per integration, used by the
# signature stubs in this section.
SklearnModel = BaseEstimator
PyTorchModel = torch.nn.Module
# Either an in-memory Keras model or a string path to a SavedModel.
TensorFlowModel = Union[tf.keras.Model, str]
# Native booster or the library's model wrapper class.
XGBoostModel = Union[xgboost.Booster, xgboost.XGBModel]
LightGBMModel = Union[lightgbm.Booster, lightgbm.LGBMModel]
1009
1010
# Framework-specific logging function signatures
def sklearn_log_model(
    sk_model: SklearnModel,
    artifact_path: str,
    **kwargs,
) -> 'ModelInfo':
    """
    Signature stub for logging a scikit-learn model.

    Parameters:
    - sk_model: SklearnModel - Model object to log
    - artifact_path: str - Run-relative artifact path
    - **kwargs: Additional keyword options accepted by the logging call

    Returns:
    ModelInfo object with logged model details
    """
1016
1017
def pytorch_log_model(
    pytorch_model: PyTorchModel,
    artifact_path: str,
    **kwargs,
) -> 'ModelInfo':
    """
    Signature stub for logging a PyTorch model.

    Parameters:
    - pytorch_model: PyTorchModel - Model object to log
    - artifact_path: str - Run-relative artifact path
    - **kwargs: Additional keyword options accepted by the logging call

    Returns:
    ModelInfo object with logged model details
    """
1022
1023
def tensorflow_log_model(
    tf_saved_model_dir: str,
    artifact_path: str,
    **kwargs,
) -> 'ModelInfo':
    """
    Signature stub for logging a TensorFlow model.

    NOTE(review): the first parameter is a SavedModel directory path only,
    while the TensorFlowModel alias in this file also allows a Keras model
    object - confirm which form the targeted MLflow version expects.

    Parameters:
    - tf_saved_model_dir: str - Path to the SavedModel directory
    - artifact_path: str - Run-relative artifact path
    - **kwargs: Additional keyword options accepted by the logging call

    Returns:
    ModelInfo object with logged model details
    """
1028
1029
def xgboost_log_model(
    xgb_model: XGBoostModel,
    artifact_path: str,
    **kwargs,
) -> 'ModelInfo':
    """
    Signature stub for logging an XGBoost model.

    Parameters:
    - xgb_model: XGBoostModel - Booster or XGBModel instance to log
    - artifact_path: str - Run-relative artifact path
    - **kwargs: Additional keyword options accepted by the logging call

    Returns:
    ModelInfo object with logged model details
    """
1034
1035
def lightgbm_log_model(
    lgb_model: LightGBMModel,
    artifact_path: str,
    **kwargs,
) -> 'ModelInfo':
    """
    Signature stub for logging a LightGBM model.

    Parameters:
    - lgb_model: LightGBMModel - Booster or LGBMModel instance to log
    - artifact_path: str - Run-relative artifact path
    - **kwargs: Additional keyword options accepted by the logging call

    Returns:
    ModelInfo object with logged model details
    """
1040
1041
# Loading function return types: every loader takes a model URI string and
# returns the model type alias of its framework.
def sklearn_load_model(model_uri: str) -> SklearnModel:
    """Signature stub: load a scikit-learn model from `model_uri`."""


def pytorch_load_model(model_uri: str) -> PyTorchModel:
    """Signature stub: load a PyTorch model from `model_uri`."""


def tensorflow_load_model(model_uri: str) -> TensorFlowModel:
    """Signature stub: load a TensorFlow model from `model_uri`."""


def xgboost_load_model(model_uri: str) -> XGBoostModel:
    """Signature stub: load an XGBoost model from `model_uri`."""


def lightgbm_load_model(model_uri: str) -> LightGBMModel:
    """Signature stub: load a LightGBM model from `model_uri`."""
1047
1048
# Autolog configuration types
AutologConfig = Dict[str, Union[bool, str, Dict[str, Any]]]


def autolog_function(
    log_input_examples: bool = False,
    log_model_signatures: bool = True,
    log_models: bool = True,
    disable: bool = False,
    exclusive: bool = False,
    disable_for_unsupported_versions: bool = False,
    silent: bool = False,
    registered_model_name: Optional[str] = None,
    log_datasets: bool = True,
    extra_tags: Optional[Dict[str, str]] = None,
    **kwargs,
) -> None:
    """
    Signature stub for a framework autolog entry point.

    `log_datasets` and `extra_tags` are listed explicitly because the
    autologging example in this document passes them; they are placed after
    the pre-existing parameters so positional use is unaffected.

    Parameters:
    - log_input_examples: bool - Log input examples with autologged models
    - log_model_signatures: bool - Log model signatures with autologged models
    - log_models: bool - Log trained models as artifacts
    - disable: bool - Disable autologging instead of enabling it
    - exclusive: bool - Do not log autologged content to user-created runs
    - disable_for_unsupported_versions: bool - Skip autologging for untested library versions
    - silent: bool - Suppress MLflow event logs and warnings during autologging
    - registered_model_name: str, optional - Name for model registry
    - log_datasets: bool - Log dataset information
    - extra_tags: dict, optional - Extra tags to set on autologged runs
    - **kwargs: Additional flavor-specific options

    Returns:
    None
    """
1062
1063
# Framework-specific types
class TorchStateDict:
    """PyTorch model state dictionary type."""

    # Placeholder type: declares no attributes or methods.
1067
1068
class SparkPipeline:
    """Spark ML Pipeline type."""

    # Placeholder type: declares no attributes or methods.
1071
1072
class TransformersPipeline:
    """Hugging Face Transformers Pipeline type."""

    # Placeholder type: declares no attributes or methods.
1075
1076
# Serialization format constants.
# The scikit-learn log_model entry in this document defaults to the pickle
# format and lists pickle and cloudpickle as its options.
SERIALIZATION_FORMAT_PICKLE = "pickle"
SERIALIZATION_FORMAT_CLOUDPICKLE = "cloudpickle"
# NOTE(review): "json" is not among the formats listed for scikit-learn
# log_model - confirm which flavor accepts it.
SERIALIZATION_FORMAT_JSON = "json"
1080
```