# Models

The Model Control Plane provides a centralized model namespace for organizing artifacts, metadata, and versions. It enables tracking model evolution, linking artifacts, and managing model lifecycle stages.

## Capabilities

### Model Class

Model configuration for grouping artifacts and metadata.
```python { .api }
class Model:
    """
    Model configuration for grouping artifacts and metadata.

    Used in pipeline or step decorators to associate runs with a model
    namespace in the Model Control Plane.

    Attributes:
    - name: Model name (required)
    - version: Model version or stage (e.g., "1.0.0", "production", "staging")
    - license: Model license (e.g., "Apache-2.0", "MIT")
    - description: Model description
    - audience: Target audience (e.g., "Data Scientists", "ML Engineers")
    - use_cases: Use cases description
    - limitations: Known limitations
    - trade_offs: Trade-offs made in model design
    - ethics: Ethical considerations
    - tags: List of tag names
    - save_models_to_registry: Auto-save to model registry (default: True)
    - suppress_class_validation_warnings: Suppress validation warnings
    """

    def __init__(
        self,
        name: str,
        version: str = None,
        license: str = None,
        description: str = None,
        audience: str = None,
        use_cases: str = None,
        limitations: str = None,
        trade_offs: str = None,
        ethics: str = None,
        tags: list = None,
        save_models_to_registry: bool = True,
        suppress_class_validation_warnings: bool = False
    ):
        """
        Initialize Model configuration.

        Parameters:
        - name: Model name (required)
        - version: Model version or stage name
        - license: License identifier
        - description: Detailed model description
        - audience: Target audience
        - use_cases: Intended use cases
        - limitations: Known limitations
        - trade_offs: Design trade-offs
        - ethics: Ethical considerations
        - tags: List of tags
        - save_models_to_registry: Whether to auto-save to registry
        - suppress_class_validation_warnings: Suppress warnings

        Example:
        ```python
        from zenml import pipeline, Model

        model = Model(
            name="sentiment_classifier",
            version="1.0.0",
            license="Apache-2.0",
            description="BERT-based sentiment classifier",
            audience="Data Scientists, ML Engineers",
            use_cases="Customer feedback analysis, social media monitoring",
            limitations="English language only, max 512 tokens",
            trade_offs="Accuracy vs inference speed",
            ethics="May exhibit bias on certain demographic groups",
            tags=["nlp", "classification", "bert"]
        )

        @pipeline(model=model)
        def training_pipeline():
            # Pipeline steps
            pass
        ```
        """
```

Import from:

```python
from zenml import Model
```

### Model Stages Enum

```python { .api }
class ModelStages(str, Enum):
    """
    Model lifecycle stages.

    Values:
    - NONE: No specific stage
    - STAGING: Model in staging environment
    - PRODUCTION: Model in production
    - ARCHIVED: Archived model
    - LATEST: Latest model version (special marker)
    """
    NONE = "none"
    STAGING = "staging"
    PRODUCTION = "production"
    ARCHIVED = "archived"
    LATEST = "latest"
```

Import from:

```python
from zenml.enums import ModelStages
```

### Log Model Metadata

Log metadata for a model version.

```python { .api }
def log_model_metadata(
    metadata: dict,
    model_name: str = None,
    model_version: str = None
):
    """
    Log metadata for a model version.

    Can be called within a pipeline/step to attach metadata to the
    configured model, or called outside to attach metadata to any model.

    Parameters:
    - metadata: Metadata dict to log (keys must be strings)
    - model_name: Model name (uses current context if None)
    - model_version: Model version (uses current context if None)

    Example:
    ```python
    from zenml import step, log_model_metadata

    @step
    def evaluate_model(model: dict, test_data: list) -> float:
        accuracy = 0.95

        # Log evaluation metrics as model metadata
        log_model_metadata(
            metadata={
                "test_accuracy": accuracy,
                "test_samples": len(test_data),
                "test_date": "2024-01-15"
            }
        )

        return accuracy

    # Log metadata outside pipeline
    from zenml import log_model_metadata

    log_model_metadata(
        metadata={
            "production_ready": True,
            "reviewer": "ml-team",
            "approval_date": "2024-01-20"
        },
        model_name="sentiment_classifier",
        model_version="1.0.0"
    )
    ```
    """
```

Import from:

```python
from zenml import log_model_metadata
```

### Link Artifact to Model

Link an artifact to a model version.

```python { .api }
def link_artifact_to_model(
    artifact_version,
    model=None
):
    """
    Link an artifact to a model version.

    Creates an association between an artifact version and a model version,
    useful for tracking model dependencies and related artifacts.

    Parameters:
    - artifact_version: ArtifactVersionResponse object to link
    - model: Model object to link to (uses current context if None)

    Raises:
    RuntimeError: If called without model parameter and no model context exists

    Example:
    ```python
    from zenml import link_artifact_to_model, save_artifact, Model
    from zenml.client import Client

    # Within a step or pipeline with model context
    artifact_version = save_artifact(data, name="preprocessor")
    link_artifact_to_model(artifact_version)  # Uses context model

    # Outside step with explicit model
    client = Client()
    artifact = client.get_artifact_version("preprocessor", version="v1.0")
    model = Model(name="sentiment_classifier", version="1.0.0")
    link_artifact_to_model(artifact, model=model)
    ```
    """
```

Import from:

```python
from zenml import link_artifact_to_model
```

## Usage Examples

### Basic Model Configuration

```python
from zenml import pipeline, step, Model

# Define model configuration
model_config = Model(
    name="fraud_detector",
    version="1.0.0",
    license="MIT",
    description="XGBoost-based fraud detection model",
    tags=["fraud", "xgboost", "production"]
)

@step
def train_model(data: list) -> dict:
    """Train fraud detection model."""
    return {"model": "trained", "accuracy": 0.97}

@pipeline(model=model_config)
def fraud_detection_pipeline():
    """Pipeline with model tracking."""
    data = [1, 2, 3, 4, 5]
    model = train_model(data)
    return model

if __name__ == "__main__":
    fraud_detection_pipeline()
```

### Model with Comprehensive Metadata

```python
from zenml import pipeline, Model

model = Model(
    name="recommendation_engine",
    version="2.1.0",
    license="Apache-2.0",
    description=(
        "Collaborative filtering recommendation engine using "
        "matrix factorization with neural network embeddings"
    ),
    audience="Product teams, ML engineers, data scientists",
    use_cases=(
        "E-commerce product recommendations, content personalization, "
        "user similarity matching"
    ),
    limitations=(
        "Requires minimum 100 interactions per user for accurate recommendations. "
        "Cold start problem for new users/items. English language content only."
    ),
    trade_offs=(
        "Increased model complexity for better accuracy results in higher "
        "inference latency (50ms vs 20ms for simpler model)"
    ),
    ethics=(
        "May reinforce filter bubbles. Recommendations should be diversified. "
        "Privacy considerations for user interaction data."
    ),
    tags=["recommendations", "collaborative-filtering", "neural-network"]
)

@pipeline(model=model)
def recommendation_pipeline():
    """Build recommendation model."""
    pass
```

### Using Model Stages

```python
from zenml import Model
from zenml.enums import ModelStages

# Reference production model
production_model = Model(
    name="text_classifier",
    version=ModelStages.PRODUCTION
)

# Reference staging model
staging_model = Model(
    name="text_classifier",
    version=ModelStages.STAGING
)

# Reference latest model
latest_model = Model(
    name="text_classifier",
    version=ModelStages.LATEST
)
```

### Logging Model Metadata

```python
from zenml import step, pipeline, Model, log_model_metadata

model_config = Model(name="image_classifier", version="3.0.0")

@step
def train_model(data: list) -> dict:
    """Train model."""
    model = {"weights": [0.1, 0.2], "accuracy": 0.94}

    # Log training metadata
    log_model_metadata({
        "training_samples": len(data),
        "training_time": "3600s",
        "optimizer": "adam",
        "learning_rate": 0.001
    })

    return model

@step
def evaluate_model(model: dict, test_data: list) -> dict:
    """Evaluate model."""
    metrics = {
        "accuracy": 0.94,
        "precision": 0.92,
        "recall": 0.95,
        "f1": 0.93
    }

    # Log evaluation metrics
    log_model_metadata({
        "test_accuracy": metrics["accuracy"],
        "test_precision": metrics["precision"],
        "test_recall": metrics["recall"],
        "test_f1": metrics["f1"],
        "test_samples": len(test_data)
    })

    return metrics

@pipeline(model=model_config)
def full_pipeline():
    """Training and evaluation pipeline."""
    data = [1, 2, 3, 4, 5]
    model = train_model(data)
    metrics = evaluate_model(model, [6, 7, 8])
    return metrics
```

### Managing Models with Client

```python
from zenml.client import Client
from zenml.enums import ModelStages

client = Client()

# Create model namespace
model = client.create_model(
    name="customer_churn_predictor",
    license="MIT",
    description="Predicts customer churn probability",
    tags=["churn", "classification"]
)

# Create model version
version = client.create_model_version(
    model_name_or_id=model.id,
    version="1.0.0",
    description="Initial production release",
    tags=["production", "v1"]
)

# Update model version stage
client.update_model_version(
    model_name_or_id=model.id,
    version_name_or_id=version.id,
    stage=ModelStages.PRODUCTION
)

# List all model versions
versions = client.list_model_versions(model_name_or_id=model.id)
for v in versions:
    print(f"Version: {v.version}, Stage: {v.stage}")

# Get model version by stage
prod_version = client.get_model_version(
    model_name_or_id=model.name,
    version=ModelStages.PRODUCTION
)
print(f"Production version: {prod_version.version}")
```

### Linking Artifacts to Models

```python
from zenml import step, pipeline, Model, save_artifact, link_artifact_to_model
from zenml.client import Client

model_config = Model(name="nlp_model", version="1.0.0")

@step
def create_preprocessor() -> dict:
    """Create text preprocessor."""
    return {"tokenizer": "bert", "max_length": 512}

@pipeline(model=model_config)
def training_pipeline():
    """Pipeline that creates related artifacts."""
    preprocessor = create_preprocessor()
    return preprocessor

# Run pipeline
training_pipeline()

# Link external artifact to model
model = Model(name="nlp_model", version="1.0.0")

# Save additional artifact
vocab_artifact = save_artifact(
    data={"vocab": ["hello", "world"], "size": 30000},
    name="vocabulary"
)

# Link to model
link_artifact_to_model(
    artifact_version=vocab_artifact,
    model=model
)

# List model artifacts via client
client = Client()
model_version = client.get_model_version("nlp_model", version="1.0.0")
artifact_links = client.list_model_version_artifact_links(
    model_version_id=model_version.id
)
for link in artifact_links:
    print(f"Linked artifact: {link.artifact_name}")
```

### Model Versioning Strategy

```python
from zenml import pipeline, Model
from datetime import datetime

# Semantic versioning
model_v1 = Model(name="detector", version="1.0.0")
model_v1_1 = Model(name="detector", version="1.1.0")
model_v2 = Model(name="detector", version="2.0.0")

# Date-based versioning
model_dated = Model(
    name="detector",
    version=f"v{datetime.now().strftime('%Y%m%d')}"
)

# Stage-based (for inference pipelines)
model_prod = Model(name="detector", version="production")
model_staging = Model(name="detector", version="staging")

# Hash-based (for reproducibility)
model_hash = Model(
    name="detector",
    version="abc123def"  # Git commit hash or data hash
)
```