ZenML is a unified MLOps framework that extends battle-tested machine learning operations principles to support the entire AI stack, from classical machine learning models to advanced AI agents.
The Model Control Plane provides a centralized model namespace for organizing artifacts, metadata, and versions. It enables tracking model evolution, linking artifacts, and managing model lifecycle stages.
Model configuration for grouping artifacts and metadata.
from typing import Optional


class Model:
    """Model configuration for grouping artifacts and metadata.

    Used in pipeline or step decorators to associate runs with a model
    namespace in the Model Control Plane.

    Import from:
        from zenml import Model

    Attributes:
        name: Model name (required).
        version: Model version or stage (e.g., "1.0.0", "production", "staging").
        license: Model license (e.g., "Apache-2.0", "MIT").
        description: Model description.
        audience: Target audience (e.g., "Data Scientists", "ML Engineers").
        use_cases: Use cases description.
        limitations: Known limitations.
        trade_offs: Trade-offs made in model design.
        ethics: Ethical considerations.
        tags: List of tag names.
        save_models_to_registry: Auto-save to model registry (default: True).
        suppress_class_validation_warnings: Suppress validation warnings.
    """

    def __init__(
        self,
        name: str,
        version: Optional[str] = None,
        license: Optional[str] = None,  # shadows builtins.license; kept for API compatibility
        description: Optional[str] = None,
        audience: Optional[str] = None,
        use_cases: Optional[str] = None,
        limitations: Optional[str] = None,
        trade_offs: Optional[str] = None,
        ethics: Optional[str] = None,
        tags: Optional[list] = None,
        save_models_to_registry: bool = True,
        suppress_class_validation_warnings: bool = False,
    ):
        """Initialize Model configuration.

        Parameters:
            name: Model name (required).
            version: Model version or stage name.
            license: License identifier.
            description: Detailed model description.
            audience: Target audience.
            use_cases: Intended use cases.
            limitations: Known limitations.
            trade_offs: Design trade-offs.
            ethics: Ethical considerations.
            tags: List of tags.
            save_models_to_registry: Whether to auto-save to registry.
            suppress_class_validation_warnings: Suppress warnings.

        Example:
            ```python
            from zenml import pipeline, Model

            model = Model(
                name="sentiment_classifier",
                version="1.0.0",
                license="Apache-2.0",
                description="BERT-based sentiment classifier",
                audience="Data Scientists, ML Engineers",
                use_cases="Customer feedback analysis, social media monitoring",
                limitations="English language only, max 512 tokens",
                trade_offs="Accuracy vs inference speed",
                ethics="May exhibit bias on certain demographic groups",
                tags=["nlp", "classification", "bert"]
            )

            @pipeline(model=model)
            def training_pipeline():
                # Pipeline steps
                pass
            ```
        """
        # The original stub never stored its arguments; persist them so the
        # attributes documented on the class are actually readable.
        self.name = name
        self.version = version
        self.license = license
        self.description = description
        self.audience = audience
        self.use_cases = use_cases
        self.limitations = limitations
        self.trade_offs = trade_offs
        self.ethics = ethics
        # Copy the tag list so the caller's list is never aliased; None -> [].
        self.tags = list(tags) if tags else []
        self.save_models_to_registry = save_models_to_registry
        self.suppress_class_validation_warnings = suppress_class_validation_warnings
from enum import Enum


class ModelStages(str, Enum):
    """Model lifecycle stages.

    Import from:
        from zenml.enums import ModelStages

    Values:
        NONE: No specific stage.
        STAGING: Model in staging environment.
        PRODUCTION: Model in production.
        ARCHIVED: Archived model.
        LATEST: Latest model version (special marker).
    """

    NONE = "none"
    STAGING = "staging"
    PRODUCTION = "production"
    ARCHIVED = "archived"
    LATEST = "latest"
from zenml.enums import ModelStages

Log metadata for a model version.
from typing import Optional


def log_model_metadata(
    metadata: dict,
    model_name: Optional[str] = None,
    model_version: Optional[str] = None,
):
    """Log metadata for a model version.

    Can be called within a pipeline/step to attach metadata to the
    configured model, or called outside to attach metadata to any model.

    Import from:
        from zenml import log_model_metadata

    Parameters:
        metadata: Metadata dict to log (keys must be strings).
        model_name: Model name (uses current context if None).
        model_version: Model version (uses current context if None).

    Example:
        ```python
        from zenml import step, log_model_metadata

        @step
        def evaluate_model(model: dict, test_data: list) -> float:
            accuracy = 0.95
            # Log evaluation metrics as model metadata
            log_model_metadata(
                metadata={
                    "test_accuracy": accuracy,
                    "test_samples": len(test_data),
                    "test_date": "2024-01-15"
                }
            )
            return accuracy

        # Log metadata outside pipeline
        from zenml import log_model_metadata

        log_model_metadata(
            metadata={
                "production_ready": True,
                "reviewer": "ml-team",
                "approval_date": "2024-01-20"
            },
            model_name="sentiment_classifier",
            model_version="1.0.0"
        )
        ```
    """
from zenml import log_model_metadata

Link an artifact to a model version.
def link_artifact_to_model(
    artifact_version,
    model=None,
):
    """Link an artifact to a model version.

    Creates an association between an artifact version and a model version,
    useful for tracking model dependencies and related artifacts.

    Import from:
        from zenml import link_artifact_to_model

    Parameters:
        artifact_version: ArtifactVersionResponse object to link.
        model: Model object to link to (uses current context if None).

    Raises:
        RuntimeError: If called without the model parameter and no model
            context exists.

    Example:
        ```python
        from zenml import link_artifact_to_model, save_artifact, Model
        from zenml.client import Client

        # Within a step or pipeline with model context
        artifact_version = save_artifact(data, name="preprocessor")
        link_artifact_to_model(artifact_version)  # Uses context model

        # Outside step with explicit model
        client = Client()
        artifact = client.get_artifact_version("preprocessor", version="v1.0")
        model = Model(name="sentiment_classifier", version="1.0.0")
        link_artifact_to_model(artifact, model=model)
        ```
    """
from zenml import link_artifact_to_model

from zenml import pipeline, step, Model
# Define model configuration
model_config = Model(
    name="fraud_detector",
    version="1.0.0",
    license="MIT",
    description="XGBoost-based fraud detection model",
    tags=["fraud", "xgboost", "production"]
)


@step
def train_model(data: list) -> dict:
    """Train fraud detection model."""
    return {"model": "trained", "accuracy": 0.97}


@pipeline(model=model_config)
def fraud_detection_pipeline():
    """Pipeline with model tracking."""
    data = [1, 2, 3, 4, 5]
    model = train_model(data)
    return model


if __name__ == "__main__":
    fraud_detection_pipeline()

from zenml import pipeline, Model
model = Model(
    name="recommendation_engine",
    version="2.1.0",
    license="Apache-2.0",
    description=(
        "Collaborative filtering recommendation engine using "
        "matrix factorization with neural network embeddings"
    ),
    audience="Product teams, ML engineers, data scientists",
    use_cases=(
        "E-commerce product recommendations, content personalization, "
        "user similarity matching"
    ),
    limitations=(
        "Requires minimum 100 interactions per user for accurate recommendations. "
        "Cold start problem for new users/items. English language content only."
    ),
    trade_offs=(
        "Increased model complexity for better accuracy results in higher "
        "inference latency (50ms vs 20ms for simpler model)"
    ),
    ethics=(
        "May reinforce filter bubbles. Recommendations should be diversified. "
        "Privacy considerations for user interaction data."
    ),
    tags=["recommendations", "collaborative-filtering", "neural-network"]
)


@pipeline(model=model)
def recommendation_pipeline():
    """Build recommendation model."""
    pass

from zenml import Model
from zenml.enums import ModelStages

# Reference production model
production_model = Model(
    name="text_classifier",
    version=ModelStages.PRODUCTION
)

# Reference staging model
staging_model = Model(
    name="text_classifier",
    version=ModelStages.STAGING
)

# Reference latest model
latest_model = Model(
    name="text_classifier",
    version=ModelStages.LATEST
)

from zenml import step, pipeline, Model, log_model_metadata
model_config = Model(name="image_classifier", version="3.0.0")


@step
def train_model(data: list) -> dict:
    """Train model."""
    model = {"weights": [0.1, 0.2], "accuracy": 0.94}
    # Log training metadata
    log_model_metadata({
        "training_samples": len(data),
        "training_time": "3600s",
        "optimizer": "adam",
        "learning_rate": 0.001
    })
    return model


@step
def evaluate_model(model: dict, test_data: list) -> dict:
    """Evaluate model."""
    metrics = {
        "accuracy": 0.94,
        "precision": 0.92,
        "recall": 0.95,
        "f1": 0.93
    }
    # Log evaluation metrics
    log_model_metadata({
        "test_accuracy": metrics["accuracy"],
        "test_precision": metrics["precision"],
        "test_recall": metrics["recall"],
        "test_f1": metrics["f1"],
        "test_samples": len(test_data)
    })
    return metrics


@pipeline(model=model_config)
def full_pipeline():
    """Training and evaluation pipeline."""
    data = [1, 2, 3, 4, 5]
    model = train_model(data)
    metrics = evaluate_model(model, [6, 7, 8])
    return metrics

from zenml.client import Client
from zenml.enums import ModelStages

client = Client()

# Create model namespace
model = client.create_model(
    name="customer_churn_predictor",
    license="MIT",
    description="Predicts customer churn probability",
    tags=["churn", "classification"]
)

# Create model version
version = client.create_model_version(
    model_name_or_id=model.id,
    version="1.0.0",
    description="Initial production release",
    tags=["production", "v1"]
)

# Update model version stage
client.update_model_version(
    model_name_or_id=model.id,
    version_name_or_id=version.id,
    stage=ModelStages.PRODUCTION
)

# List all model versions
versions = client.list_model_versions(model_name_or_id=model.id)
for v in versions:
    print(f"Version: {v.version}, Stage: {v.stage}")

# Get model version by stage
prod_version = client.get_model_version(
    model_name_or_id=model.name,
    version=ModelStages.PRODUCTION
)
print(f"Production version: {prod_version.version}")

from zenml import step, pipeline, Model, save_artifact, link_artifact_to_model
from zenml.client import Client

model_config = Model(name="nlp_model", version="1.0.0")


@step
def create_preprocessor() -> dict:
    """Create text preprocessor."""
    return {"tokenizer": "bert", "max_length": 512}


@pipeline(model=model_config)
def training_pipeline():
    """Pipeline that creates related artifacts."""
    preprocessor = create_preprocessor()
    return preprocessor


# Run pipeline
training_pipeline()

# Link external artifact to model
model = Model(name="nlp_model", version="1.0.0")

# Save additional artifact
vocab_artifact = save_artifact(
    data={"vocab": ["hello", "world"], "size": 30000},
    name="vocabulary"
)

# Link to model
link_artifact_to_model(
    artifact_version=vocab_artifact,
    model=model
)

# List model artifacts via client
client = Client()
model_version = client.get_model_version("nlp_model", version="1.0.0")
artifact_links = client.list_model_version_artifact_links(
    model_version_id=model_version.id
)
for link in artifact_links:
    print(f"Linked artifact: {link.artifact_name}")

from zenml import pipeline, Model
from datetime import datetime
# Semantic versioning
model_v1 = Model(name="detector", version="1.0.0")
model_v1_1 = Model(name="detector", version="1.1.0")
model_v2 = Model(name="detector", version="2.0.0")
# Date-based versioning
model_dated = Model(
name="detector",
version=f"v{datetime.now().strftime('%Y%m%d')}"
)
# Stage-based (for inference pipelines)
model_prod = Model(name="detector", version="production")
model_staging = Model(name="detector", version="staging")
# Hash-based (for reproducibility)
model_hash = Model(
name="detector",
version="abc123def" # Git commit hash or data hash
)Install with Tessl CLI
npx tessl i tessl/pypi-zenml