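# Model Management

MLflow's model management APIs cover the model lifecycle: logging, loading, evaluation, registry operations, and deployment. Models are saved in the MLflow Model format. A single saved model can expose several "flavors", for example a native scikit-learn flavor and the generic `python_function` flavor. This lets framework-specific code load the artifact directly, and lets it be served through a common interface.

## Capabilities

### Model Logging and Loading

Core functions for saving, loading, and inspecting models. Each flavor module (for example `mlflow.sklearn` or `mlflow.pyfunc`) provides its own `log_model` and `load_model`. The signatures below show the parameters they have in common. In the flavor-specific versions, the first argument is named after the flavor (`sk_model`, `python_model`, and so on).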

```python { .api }
def log_model(model, artifact_path, registered_model_name=None, signature=None, input_example=None, await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS, pip_requirements=None, extra_pip_requirements=None, conda_env=None, metadata=None, **kwargs):
    """
    Log a machine learning model as an MLflow artifact of the active run.

    Parameters:
    - model: Model object - The model to be logged (flavor-specific name, e.g. sk_model)
    - artifact_path: str - Relative artifact path within the run
    - registered_model_name: str, optional - If given, also register the model under this name
    - signature: ModelSignature, optional - Model input/output signature
    - input_example: Any, optional - Example model input, stored with the model
    - await_registration_for: int - Seconds to wait for the registered version to become ready (default 300)
    - pip_requirements: list, optional - Pip requirements that replace the inferred ones
    - extra_pip_requirements: list, optional - Pip requirements appended to the inferred ones
    - conda_env: str or dict, optional - Conda environment specification
    - metadata: dict, optional - Custom model metadata

    At most one of pip_requirements, extra_pip_requirements, and conda_env may be set.

    Returns:
    ModelInfo object with logged model details
    """


def load_model(model_uri, dst_path=None, **kwargs):
    """
    Load an MLflow model from a URI such as runs:/<run_id>/<path> or models:/<name>/<version>.

    Parameters:
    - model_uri: str - URI pointing to MLflow model
    - dst_path: str, optional - Local directory to download the model artifacts to
    - kwargs: Additional framework-specific arguments

    Returns:
    Loaded model object ready for inference
    """


def predict(model_uri, input_data, content_type=None, json_format=None, **kwargs):
    """
    Generate predictions using an MLflow model.

    Parameters:
    - model_uri: str - URI pointing to MLflow model
    - input_data: DataFrame, array, or dict - Input data for predictions
    - content_type: str, optional - Input data content type
    - json_format: str, optional - JSON serialization format
    - kwargs: Additional prediction arguments

    Returns:
    Predictions in framework-specific format
    """


def get_model_info(model_uri):
    """
    Get model metadata without loading the model itself.

    Parameters:
    - model_uri: str - URI pointing to MLflow model

    Returns:
    ModelInfo object with model metadata and signature
    """


def set_model(model):
    """
    Specify the model object defined in a models-from-code Python file. When that
    file is logged (for example as python_model="model.py"), MLflow loads the
    object passed here as the model.

    Parameters:
    - model: Model object - The model instance defined in the file
    """


def update_model_requirements(model_uri, operation, requirement_list):
    """
    Add or remove pip requirements of an already-logged model.

    Parameters:
    - model_uri: str - URI pointing to MLflow model
    - operation: str - "add" or "remove"
    - requirement_list: list of str - Requirement specifiers, e.g. ["pandas==2.2.2"]
    """
```
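
`log_model` offers three mutually exclusive ways to describe a model's environment. `pip_requirements` replaces the requirements MLflow would otherwise infer. `extra_pip_requirements` appends to the inferred list. `conda_env` supplies a full Conda specification. The sketch below models only that precedence rule in plain Python. `resolve_requirements` is a hypothetical helper, not an MLflow internal. It takes the inferred list as an input; in MLflow, that list is computed by analyzing the model. The real implementation handles more cases than shown here.

```python
from typing import Dict, List, Optional


def resolve_requirements(
    inferred: List[str],
    pip_requirements: Optional[List[str]] = None,
    extra_pip_requirements: Optional[List[str]] = None,
    conda_env: Optional[Dict] = None,
) -> List[str]:
    """Hypothetical sketch of how log_model's environment arguments interact."""
    given = [a for a in (pip_requirements, extra_pip_requirements, conda_env) if a is not None]
    if len(given) > 1:
        raise ValueError(
            "Only one of conda_env, pip_requirements, and extra_pip_requirements may be set"
        )
    if conda_env is not None:
        # Simplified: use the pip section of a Conda spec such as
        # {"dependencies": ["python=3.10", {"pip": ["scikit-learn"]}]}
        for dep in conda_env.get("dependencies", []):
            if isinstance(dep, dict) and "pip" in dep:
                return list(dep["pip"])
        return []
    if pip_requirements is not None:
        # An explicit list replaces the inferred requirements
        return list(pip_requirements)
    if extra_pip_requirements is not None:
        # Extras are appended to the inferred requirements
        return list(inferred) + list(extra_pip_requirements)
    return list(inferred)


print(resolve_requirements(["scikit-learn", "numpy"], extra_pip_requirements=["pandas"]))
# ['scikit-learn', 'numpy', 'pandas']
```

In practice, `extra_pip_requirements` is the safer default. It keeps whatever MLflow detected and adds only what detection missed.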
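
### Model Evaluation

`evaluate()` runs a model on a dataset and computes the built-in metrics for the declared `model_type`. It logs the metrics and artifacts (such as plots) to the active run. Custom metrics are built with `make_metric()` and passed through `extra_metrics`.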

```python { .api }
def evaluate(model=None, data=None, targets=None, model_type=None, evaluators=None, evaluator_config=None, custom_metrics=None, extra_metrics=None, custom_artifacts=None, baseline_model=None, env_manager=None, model_config=None, baseline_config=None, inference_params=None, baseline_inference_params=None):
    """
    Evaluate model performance and log the results to the active run.

    Parameters:
    - model: Model, callable, or URI - Model to evaluate
    - data: DataFrame, array, or URI - Evaluation dataset
    - targets: str or array, optional - Target column name or values
    - model_type: str, optional - Type of model, e.g. "classifier" or "regressor"
    - evaluators: list, optional - List of evaluator names or objects
    - evaluator_config: dict, optional - Configuration for evaluators
    - custom_metrics: list, optional - Deprecated; use extra_metrics
    - extra_metrics: list, optional - Additional EvaluationMetric objects, built-in or from make_metric()
    - custom_artifacts: list, optional - Custom artifact generators
    - baseline_model: Model or URI, optional - Baseline for comparison
    - env_manager: str, optional - How to restore the model's dependencies ("local", "virtualenv", or "conda")
    - model_config: dict, optional - Model configuration parameters
    - baseline_config: dict, optional - Baseline model configuration
    - inference_params: dict, optional - Model inference parameters
    - baseline_inference_params: dict, optional - Baseline inference parameters

    Returns:
    EvaluationResult object with metrics and artifacts
    """


def list_evaluators():
    """
    List available built-in evaluators.

    Returns:
    List of evaluator names
    """


def make_metric(*, eval_fn, greater_is_better, name=None, long_name=None, version=None, metric_details=None, metric_metadata=None, genai_metric_args=None):
    """
    Create a custom evaluation metric. All arguments are keyword-only.

    Parameters:
    - eval_fn: callable - Function that computes the metric from predictions and targets
    - greater_is_better: bool - Whether higher values are better (required)
    - name: str, optional - Metric name (inferred if not provided)
    - long_name: str, optional - Human-readable metric name
    - version: str, optional - Metric version
    - metric_details: str, optional - Metric description
    - metric_metadata: dict, optional - Additional metadata
    - genai_metric_args: dict, optional - GenAI-specific arguments

    Returns:
    EvaluationMetric object
    """


def validate_evaluation_results(validation_thresholds, candidate_result, baseline_result=None):
    """
    Validate evaluation results against metric thresholds.

    Parameters:
    - validation_thresholds: dict - Mapping of metric name to MetricThreshold
    - candidate_result: EvaluationResult - Results of the candidate model
    - baseline_result: EvaluationResult, optional - Results of the baseline model, needed for change-based thresholds

    Raises:
    ModelValidationFailedException if any threshold is not met
    """
```
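
`make_metric` wraps a plain function. In recent MLflow 2.x releases, that function receives arguments matched by parameter name: `predictions` and `targets` (pandas Series) and `metrics` (a dict of already-computed metrics). It may return a float. Treat this calling convention as an assumption to check against your MLflow version. The sketch below implements balanced accuracy (the mean of per-class recall) without scikit-learn so that the arithmetic is visible. It accepts any sequences, so pandas Series work as well.

```python
from collections import defaultdict
from typing import Hashable, Iterable, Optional


def balanced_accuracy(
    predictions: Iterable[Hashable],
    targets: Iterable[Hashable],
    metrics: Optional[dict] = None,
) -> float:
    """Mean recall over the classes present in `targets`; `metrics` is unused."""
    hits = defaultdict(int)    # correct predictions per true class
    totals = defaultdict(int)  # number of samples per true class
    for pred, true in zip(predictions, targets):
        totals[true] += 1
        if pred == true:
            hits[true] += 1
    if not totals:
        raise ValueError("targets must not be empty")
    recalls = [hits[c] / totals[c] for c in totals]
    return sum(recalls) / len(recalls)


# Class 0: 2 of 3 correct, class 1: 1 of 1 correct, so (2/3 + 1) / 2
print(balanced_accuracy([0, 0, 1, 1], [0, 0, 0, 1]))
```

Passed as `make_metric(eval_fn=balanced_accuracy, greater_is_better=True, name="balanced_accuracy")`, the function would be evaluated alongside the built-in metrics.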
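
### Model Signature and Schema

A model signature describes a model's expected inputs and outputs, plus optional inference parameters. When a model is loaded as a `pyfunc` model or served, MLflow uses the signature to validate inputs and convert them to the declared types.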

```python { .api }
def infer_signature(model_input, model_output=None, params=None):
    """
    Infer a model signature from example input and output.

    Parameters:
    - model_input: DataFrame, array, dict - Example model input
    - model_output: array, dict, optional - Example model output
    - params: dict, optional - Example inference parameters, from which a parameter schema is inferred

    Returns:
    ModelSignature object describing the input/output schema
    """


def set_signature(model_uri, signature):
    """
    Set the signature of an already-logged model.

    Parameters:
    - model_uri: str - URI pointing to MLflow model
    - signature: ModelSignature - Signature to set
    """


def validate_schema(input_data, expected_schema):
    """
    Validate data against an expected schema.

    Parameters:
    - input_data: DataFrame, array, dict - Data to validate
    - expected_schema: Schema - Expected data schema

    Raises:
    MlflowException if the data does not match the schema
    """


def validate_serving_input(model_uri, serving_input):
    """
    Before deployment, check that a model can be served and that a payload is
    valid for it.

    Parameters:
    - model_uri: str - URI pointing to MLflow model
    - serving_input: str or dict - Payload in the serving API format

    Returns:
    Model predictions for the payload
    """


def convert_input_example_to_serving_input(input_example):
    """
    Convert an input example to the serving API format.

    Parameters:
    - input_example: Any - Model input example

    Returns:
    str - JSON payload in the serving API format
    """
```
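
`infer_signature` derives column types from example data. The sketch below mimics that idea for a list of dict records. It maps Python types to the type names MLflow uses in column-based schemas (`boolean`, `long`, `double`, `string`, `binary`). `infer_column_types` is a hypothetical helper. The real function also handles pandas and NumPy dtypes, tensors, optional columns, and inference parameters.

```python
from typing import Any, Dict, List

# Checked in order; bool must precede int because bool is a subclass of int
_TYPE_NAMES = [(bool, "boolean"), (int, "long"), (float, "double"), (str, "string"), (bytes, "binary")]


def _type_name(value: Any) -> str:
    for py_type, name in _TYPE_NAMES:
        if isinstance(value, py_type):
            return name
    raise TypeError(f"unsupported value type: {type(value).__name__}")


def infer_column_types(records: List[Dict[str, Any]]) -> Dict[str, str]:
    """Hypothetical helper: map each column name to a single type name."""
    schema: Dict[str, str] = {}
    for record in records:
        for column, value in record.items():
            name = _type_name(value)
            previous = schema.setdefault(column, name)
            if previous != name:
                raise TypeError(f"column {column!r} mixes {previous} and {name}")
    return schema


print(infer_column_types([{"age": 31, "income": 52000.0, "city": "Oslo"}]))
# {'age': 'long', 'income': 'double', 'city': 'string'}
```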

### Model Registry Integration

Functions for registering models and managing model versions in the MLflow Model Registry.

```python { .api }
def register_model(model_uri, name, await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS, tags=None, **kwargs):
    """
    Register a model in the MLflow Model Registry, creating the registered model if it does not exist.

    Parameters:
    - model_uri: str - URI pointing to MLflow model, e.g. runs:/<run_id>/<path>
    - name: str - Name of the registered model
    - await_registration_for: int - Seconds to wait for the new version to become ready (default 300)
    - tags: dict, optional - Tags for the model version
    - kwargs: Additional registration arguments

    Returns:
    ModelVersion object representing the registered version
    """


def add_libraries_to_model(model_uri, run_id=None, registered_model_name=None, await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS):
    """
    Re-log a registered model together with the libraries it requires, stored as model artifacts.

    Parameters:
    - model_uri: str - URI of a registered model version, e.g. models:/<name>/<version>
    - run_id: str, optional - Run ID for model artifacts
    - registered_model_name: str, optional - Name for registration
    - await_registration_for: int - Seconds to wait for completion

    Returns:
    Updated model with library dependencies
    """
```
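
Registered versions are addressed with `models:/` URIs, either by version number (`models:/<name>/<version>`) or by alias (`models:/<name>@<alias>`). Such a URI can be passed to `load_model` or `add_libraries_to_model`. The helper below only formats these strings and never contacts a registry. `registry_model_uri` and the alias `"champion"` are hypothetical names used for illustration.

```python
from typing import Optional, Union


def registry_model_uri(
    name: str,
    version: Optional[Union[int, str]] = None,
    alias: Optional[str] = None,
) -> str:
    """Format a models:/ URI for one registered model version (hypothetical helper)."""
    if not name:
        raise ValueError("name must be non-empty")
    if (version is None) == (alias is None):
        raise ValueError("pass exactly one of version or alias")
    if alias is not None:
        return f"models:/{name}@{alias}"
    return f"models:/{name}/{version}"


print(registry_model_uri("random-forest-classifier", version=1))
# models:/random-forest-classifier/1
print(registry_model_uri("random-forest-classifier", alias="champion"))
# models:/random-forest-classifier@champion
```

Aliases are useful for deployment code because they stay fixed while the version they point to changes.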

### Model Configuration and Resources

Classes for passing configuration to a model and for declaring the external resources the model needs when deployed.

```python { .api }
class ModelConfig:
    def __init__(self, *, development_config=None):
        """
        Read-only access to a model's configuration, typically used inside models-from-code files.

        Parameters:
        - development_config: str or dict, optional - YAML file path or dict used when no config was supplied at logging time

        Methods:
        - get(key): Return the value of a top-level key
        - to_dict(): Return the whole configuration as a dict
        """


class Resource:
    """
    Base class for a dependency that a model needs at serving time.
    Concrete subclasses include DatabricksServingEndpoint(endpoint_name=...)
    and DatabricksVectorSearchIndex(index_name=...).
    """


class ResourceType:
    """Enumeration of supported resource types."""
    DATABRICKS_SERVING_ENDPOINT = "databricks_serving_endpoint"
    DATABRICKS_VECTOR_SEARCH_INDEX = "databricks_vector_search_index"
```
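
In the models-from-code workflow, a config passed as `model_config` at logging time takes precedence inside the model. `development_config` is the fallback used while iterating locally. The sketch below captures only that lookup rule. `SimpleModelConfig` and its `logged_config` argument are hypothetical, and the class is not the `mlflow.models.ModelConfig` implementation, which also loads YAML files.

```python
from typing import Any, Dict, Optional


class SimpleModelConfig:
    """Hypothetical stand-in illustrating logged-config vs development-config precedence."""

    def __init__(
        self,
        development_config: Optional[Dict[str, Any]] = None,
        logged_config: Optional[Dict[str, Any]] = None,
    ):
        # A config supplied at logging time wins; otherwise fall back to the development one
        source = logged_config if logged_config is not None else development_config
        if source is None:
            raise ValueError("no configuration available")
        self._config = dict(source)

    def get(self, key: str) -> Any:
        # Top-level lookup; a missing key is an error rather than a silent None
        if key not in self._config:
            raise KeyError(f"{key!r} not found in model config")
        return self._config[key]

    def to_dict(self) -> Dict[str, Any]:
        return dict(self._config)


dev = {"batch_size": 32, "max_sequence_length": 512}
print(SimpleModelConfig(development_config=dev).get("batch_size"))  # 32
print(SimpleModelConfig(development_config=dev, logged_config={"batch_size": 8}).get("batch_size"))  # 8
```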

### Deployment and Serving

`build_docker()` packages a model, its dependencies, and MLflow's scoring server into a Docker image. A container started from the image serves the model's REST API on port 8080.

```python { .api }
def build_docker(model_uri, name="mlflow-pyfunc", env_manager=None, mlflow_home=None, install_java=False, install_mlflow=False, enable_mlserver=False, base_image=None):
    """
    Build a Docker image for model serving.

    Parameters:
    - model_uri: str - URI pointing to MLflow model
    - name: str - Name (tag) of the image to build
    - env_manager: str, optional - Environment manager used inside the image ("virtualenv" or "conda")
    - mlflow_home: str, optional - Path to a local MLflow source checkout to install (for development)
    - install_java: bool - Whether to install a Java runtime (needed by some flavors, such as Spark)
    - install_mlflow: bool - Whether to install MLflow in the model's environment
    - enable_mlserver: bool - Whether to serve with Seldon MLServer instead of the default scoring server
    - base_image: str, optional - Base Docker image

    Returns:
    None; the image is available locally under `name`
    """
```

### Model Input Examples

An input example is stored alongside the model when it is passed as `input_example` to `log_model`. It documents the expected input format. When no signature is supplied, MLflow can also infer a signature from it.

```python { .api }
class ModelInputExample:
    def __init__(self, input_example):
        """
        Container for model input example.

        Parameters:
        - input_example: Any - Example input data
        """
```
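
The scoring server that MLflow starts for a model accepts JSON bodies in several layouts. These include `dataframe_records`, a list of row dicts, and `dataframe_split`, which holds separate `columns` and `data` arrays. When an input example is kept as row dicts, converting it to the split layout is straightforward. `records_to_split` below is a hypothetical helper that uses only the standard library.

```python
import json
from typing import Any, Dict, List


def records_to_split(records: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Convert row dicts to a {"dataframe_split": {...}} payload (hypothetical helper)."""
    columns: List[str] = []
    for record in records:
        for key in record:
            if key not in columns:
                columns.append(key)  # keep first-seen column order
    # Missing values become None, which serializes as JSON null
    data = [[record.get(col) for col in columns] for record in records]
    return {"dataframe_split": {"columns": columns, "data": data}}


example = [{"x": 1.0, "y": "a"}, {"x": 2.5, "y": "b"}]
print(json.dumps(records_to_split(example)))
# {"dataframe_split": {"columns": ["x", "y"], "data": [[1.0, "a"], [2.5, "b"]]}}
```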

## Usage Examples

### Basic Model Logging and Loading

```python
import mlflow
import mlflow.sklearn
from mlflow.models import infer_signature
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Generate sample data
X, y = make_classification(n_samples=1000, n_features=20, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train model
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

# Log model with signature and input example
with mlflow.start_run():
    # Infer the input/output schema from training data and predictions
    signature = infer_signature(X_train, model.predict(X_train))

    # log_model returns a ModelInfo; registered_model_name also creates
    # a new version in the Model Registry
    model_info = mlflow.sklearn.log_model(
        sk_model=model,
        artifact_path="model",
        signature=signature,
        input_example=X_train[:3],
        registered_model_name="random-forest-classifier",
        metadata={"algorithm": "RandomForest", "framework": "scikit-learn"},
    )

# Inspect the logged model through its URI
model_uri = model_info.model_uri
print(f"Model signature: {mlflow.models.get_model_info(model_uri).signature}")

# Load model for inference
loaded_model = mlflow.sklearn.load_model(model_uri)
predictions = loaded_model.predict(X_test)
```

### Model Evaluation

```python
import mlflow
import pandas as pd
from mlflow.models import evaluate

# Reuses X_test, y_test and model_uri from the previous example
feature_names = [f"feature_{i}" for i in range(X_test.shape[1])]
eval_data = pd.DataFrame(X_test, columns=feature_names)
eval_data["target"] = y_test

# Evaluate model with built-in classifier metrics
with mlflow.start_run():
    results = evaluate(
        model=model_uri,
        data=eval_data,
        targets="target",
        model_type="classifier",
        evaluators=["default"],
        evaluator_config={
            "pos_label": 1,
            "average": "weighted",
        },
    )

# evaluate() has already logged these metrics to the run above
print("Evaluation metrics:")
for metric_name, metric_value in results.metrics.items():
    print(f"{metric_name}: {metric_value}")
```

### Custom Evaluation Metrics

```python
import mlflow
from mlflow.models import evaluate, make_metric
from sklearn.metrics import balanced_accuracy_score

# Reuses eval_data and model_uri from the previous examples


# Define custom metric; MLflow passes arguments by parameter name
def balanced_accuracy(predictions, targets, metrics):
    """Custom balanced accuracy metric."""
    return balanced_accuracy_score(targets, predictions)


# Create metric object
balanced_acc_metric = make_metric(
    eval_fn=balanced_accuracy,
    greater_is_better=True,
    name="balanced_accuracy",
    long_name="Balanced Accuracy Score",
)

# Evaluate with custom metric alongside the built-in ones
with mlflow.start_run():
    results = evaluate(
        model=model_uri,
        data=eval_data,
        targets="target",
        model_type="classifier",
        extra_metrics=[balanced_acc_metric],
    )

print(f"Balanced accuracy: {results.metrics['balanced_accuracy']}")
```

### Model Comparison and Baseline

```python
import mlflow
import mlflow.sklearn
from mlflow.models import evaluate
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier

# Reuses X_train, y_train and eval_data from the previous examples

# Train and log the baseline model
baseline_model = GradientBoostingClassifier(random_state=42)
baseline_model.fit(X_train, y_train)
with mlflow.start_run(run_name="baseline_model"):
    baseline_uri = mlflow.sklearn.log_model(baseline_model, "model").model_uri

# Train and log the candidate model
candidate_model = RandomForestClassifier(n_estimators=200, random_state=42)
candidate_model.fit(X_train, y_train)
with mlflow.start_run(run_name="candidate_model"):
    candidate_uri = mlflow.sklearn.log_model(candidate_model, "model").model_uri

# Evaluate the candidate against the baseline
with mlflow.start_run(run_name="model_comparison"):
    results = evaluate(
        model=candidate_uri,
        data=eval_data,
        targets="target",
        model_type="classifier",
        baseline_model=baseline_uri,
        evaluators=["default"],
    )

# Candidate metrics are logged to the run; baseline metrics are returned separately
for metric_name, candidate_value in results.metrics.items():
    baseline_value = results.baseline_model_metrics.get(metric_name)
    print(f"{metric_name}: candidate={candidate_value}, baseline={baseline_value}")
```
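
`MetricThreshold` expresses the acceptance rule used when a candidate is validated against a baseline. The candidate must meet an absolute `threshold`. It may also have to beat the baseline by at least `min_absolute_change` or `min_relative_change`. `greater_is_better` sets the direction of all these comparisons. The sketch below restates that rule in plain Python for a single metric. `passes_threshold` is a hypothetical function. MLflow's own validator may treat boundaries and a zero baseline differently.

```python
from typing import Optional


def passes_threshold(
    candidate: float,
    greater_is_better: bool,
    threshold: Optional[float] = None,
    baseline: Optional[float] = None,
    min_absolute_change: Optional[float] = None,
    min_relative_change: Optional[float] = None,
) -> bool:
    """Hypothetical single-metric check modeled on MetricThreshold's fields."""
    # Express comparisons as "improvement", positive when the candidate is better
    sign = 1.0 if greater_is_better else -1.0
    if threshold is not None and sign * (candidate - threshold) < 0:
        return False
    if min_absolute_change is not None or min_relative_change is not None:
        if baseline is None:
            raise ValueError("a baseline value is required for change-based checks")
        improvement = sign * (candidate - baseline)
        if min_absolute_change is not None and improvement < min_absolute_change:
            return False
        if min_relative_change is not None:
            if baseline == 0:
                raise ValueError("relative change is undefined for a zero baseline")
            if improvement / abs(baseline) < min_relative_change:
                return False
    return True


# Accuracy 0.86 vs baseline 0.80: meets the 0.8 floor and improves by 0.06 (7.5%)
print(passes_threshold(0.86, True, threshold=0.8, baseline=0.80, min_absolute_change=0.05))  # True
```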

### Model Serving with Docker

```python
import mlflow.models

# Placeholder URI: replace abc123 with a real run ID
model_uri = "runs:/abc123/model"

# build_docker() returns None; the image is tagged with `name`
mlflow.models.build_docker(
    model_uri=model_uri,
    name="my-model-serving",
    env_manager="conda",
    enable_mlserver=True,
)

print("Built Docker image: my-model-serving")

# The container serves on port 8080; map it to a host port, e.g.:
# docker run -p 5000:8080 my-model-serving
```
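
Once the container is running, predictions are requested by POSTing JSON to the scoring server's `/invocations` endpoint. The sketch below builds such a request with the standard library only, using the `dataframe_split` layout. The host and port assume the `docker run -p 5000:8080` mapping above. The column names are placeholders, and `build_invocation_request` is a hypothetical helper. Actually sending the request with `urllib.request.urlopen` requires a live server.

```python
import json
import urllib.request
from typing import Any, List


def build_invocation_request(
    columns: List[str],
    rows: List[List[Any]],
    base_url: str = "http://127.0.0.1:5000",
) -> urllib.request.Request:
    """Build (but do not send) a POST to the scoring server's /invocations endpoint."""
    body = json.dumps({"dataframe_split": {"columns": columns, "data": rows}}).encode("utf-8")
    return urllib.request.Request(
        url=f"{base_url.rstrip('/')}/invocations",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


request = build_invocation_request(["feature_0", "feature_1"], [[0.1, 1.2], [0.4, -0.3]])
print(request.full_url, request.get_method())
# http://127.0.0.1:5000/invocations POST
# With a running container: urllib.request.urlopen(request).read()
```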

### Advanced Model Configuration

```python
import mlflow
from mlflow.models import ModelConfig
from mlflow.models.resources import DatabricksVectorSearchIndex

# Inference-time configuration, passed to the model at logging time
model_config = {
    "batch_size": 32,
    "max_sequence_length": 512,
}

# Inside the model code, values are read back through ModelConfig;
# development_config is used when no config was supplied at logging time
config = ModelConfig(development_config=model_config)
print(config.get("batch_size"))

# Databricks resources the model depends on at serving time
vector_search_resource = DatabricksVectorSearchIndex(index_name="product_embeddings")

# "agent.py" is a placeholder models-from-code file that calls
# mlflow.models.set_model(...) on the model object it defines
with mlflow.start_run():
    mlflow.pyfunc.log_model(
        artifact_path="model",
        python_model="agent.py",
        model_config=model_config,
        resources=[vector_search_resource],
    )
```

## Types

```python { .api }
from typing import Any, Callable, Dict, List, Optional, Union

from mlflow.models.model import Model, ModelInfo
from mlflow.models.signature import ModelSignature
from mlflow.models.evaluation import EvaluationResult, EvaluationMetric, EvaluationArtifact, MetricThreshold
from mlflow.models.utils import ModelInputExample
from mlflow.models.model_config import ModelConfig
from mlflow.models.resources import Resource, ResourceType
from mlflow.entities.model_registry import ModelVersion, RegisteredModel
from mlflow.types import ColSpec, DataType, ParamSchema, Schema


class ModelInfo:
    artifact_path: str
    flavors: Dict[str, Any]
    model_size_bytes: int
    model_uri: str
    model_uuid: str
    run_id: str
    saved_input_example_info: Dict[str, Any]
    signature: ModelSignature
    utc_time_created: str
    mlflow_version: str
    metadata: Dict[str, Any]


class ModelSignature:
    inputs: Schema
    outputs: Schema
    params: ParamSchema


class Schema:
    inputs: List[ColSpec]


class ColSpec:
    type: DataType
    name: Optional[str]
    required: bool


class EvaluationResult:
    metrics: Dict[str, float]
    artifacts: Dict[str, EvaluationArtifact]
    run_id: str
    baseline_model_metrics: Dict[str, float]


class EvaluationMetric:
    name: str
    long_name: str
    version: str
    metric_details: str
    greater_is_better: bool
    eval_fn: Callable


class EvaluationArtifact:
    uri: str
    content: Any


class MetricThreshold:
    threshold: Optional[float]
    min_absolute_change: Optional[float]
    min_relative_change: Optional[float]
    greater_is_better: bool


class ModelInputExample:
    input_example: Any


class ModelConfig:
    development_config: Optional[Union[str, Dict[str, Any]]]


class Resource:
    # Base class; concrete subclasses include
    # DatabricksServingEndpoint and DatabricksVectorSearchIndex
    type: ResourceType


class ResourceType:
    DATABRICKS_SERVING_ENDPOINT: str
    DATABRICKS_VECTOR_SEARCH_INDEX: str


class Model:
    artifact_path: str
    flavors: Dict[str, Any]
    model_uuid: str
    mlflow_version: str
    saved_input_example_info: Dict[str, Any]
    signature: ModelSignature
    utc_time_created: str
```