# Models

The Model Control Plane provides a centralized model namespace for organizing artifacts, metadata, and versions. It enables tracking model evolution, linking artifacts, and managing model lifecycle stages.

## Capabilities

### Model Class

Model configuration for grouping artifacts and metadata.

```python { .api }
class Model:
    """
    Model configuration for grouping artifacts and metadata.

    Used in pipeline or step decorators to associate runs with a model
    namespace in the Model Control Plane.

    Attributes:
    - name: Model name (required)
    - version: Model version or stage (e.g., "1.0.0", "production", "staging")
    - license: Model license (e.g., "Apache-2.0", "MIT")
    - description: Model description
    - audience: Target audience (e.g., "Data Scientists", "ML Engineers")
    - use_cases: Use cases description
    - limitations: Known limitations
    - trade_offs: Trade-offs made in model design
    - ethics: Ethical considerations
    - tags: List of tag names
    - save_models_to_registry: Auto-save to model registry (default: True)
    - suppress_class_validation_warnings: Suppress validation warnings
    """

    def __init__(
        self,
        name: str,
        version: str = None,
        license: str = None,
        description: str = None,
        audience: str = None,
        use_cases: str = None,
        limitations: str = None,
        trade_offs: str = None,
        ethics: str = None,
        tags: list = None,
        save_models_to_registry: bool = True,
        suppress_class_validation_warnings: bool = False
    ):
        """
        Initialize Model configuration.

        Parameters:
        - name: Model name (required)
        - version: Model version or stage name
        - license: License identifier
        - description: Detailed model description
        - audience: Target audience
        - use_cases: Intended use cases
        - limitations: Known limitations
        - trade_offs: Design trade-offs
        - ethics: Ethical considerations
        - tags: List of tags
        - save_models_to_registry: Whether to auto-save to registry
        - suppress_class_validation_warnings: Suppress warnings

        Example:
        ```python
        from zenml import pipeline, Model

        model = Model(
            name="sentiment_classifier",
            version="1.0.0",
            license="Apache-2.0",
            description="BERT-based sentiment classifier",
            audience="Data Scientists, ML Engineers",
            use_cases="Customer feedback analysis, social media monitoring",
            limitations="English language only, max 512 tokens",
            trade_offs="Accuracy vs inference speed",
            ethics="May exhibit bias on certain demographic groups",
            tags=["nlp", "classification", "bert"]
        )

        @pipeline(model=model)
        def training_pipeline():
            # Pipeline steps
            pass
        ```
        """
```

Import from:

```python
from zenml import Model
```

### Model Stages Enum

```python { .api }
class ModelStages(str, Enum):
    """
    Model lifecycle stages.

    Values:
    - NONE: No specific stage
    - STAGING: Model in staging environment
    - PRODUCTION: Model in production
    - ARCHIVED: Archived model
    - LATEST: Latest model version (special marker)
    """
    NONE = "none"
    STAGING = "staging"
    PRODUCTION = "production"
    ARCHIVED = "archived"
    LATEST = "latest"
```

Import from:

```python
from zenml.enums import ModelStages
```

### Log Model Metadata

Log metadata for a model version.

```python { .api }
def log_model_metadata(
    metadata: dict,
    model_name: str = None,
    model_version: str = None
):
    """
    Log metadata for a model version.

    Can be called within a pipeline/step to attach metadata to the
    configured model, or called outside to attach metadata to any model.

    Parameters:
    - metadata: Metadata dict to log (keys must be strings)
    - model_name: Model name (uses current context if None)
    - model_version: Model version (uses current context if None)

    Example:
    ```python
    from zenml import step, log_model_metadata

    @step
    def evaluate_model(model: dict, test_data: list) -> float:
        accuracy = 0.95

        # Log evaluation metrics as model metadata
        log_model_metadata(
            metadata={
                "test_accuracy": accuracy,
                "test_samples": len(test_data),
                "test_date": "2024-01-15"
            }
        )

        return accuracy

    # Log metadata outside pipeline
    from zenml import log_model_metadata

    log_model_metadata(
        metadata={
            "production_ready": True,
            "reviewer": "ml-team",
            "approval_date": "2024-01-20"
        },
        model_name="sentiment_classifier",
        model_version="1.0.0"
    )
    ```
    """
```

Import from:

```python
from zenml import log_model_metadata
```

### Link Artifact to Model

Link an artifact to a model version.

```python { .api }
def link_artifact_to_model(
    artifact_version,
    model=None
):
    """
    Link an artifact to a model version.

    Creates an association between an artifact version and a model version,
    useful for tracking model dependencies and related artifacts.

    Parameters:
    - artifact_version: ArtifactVersionResponse object to link
    - model: Model object to link to (uses current context if None)

    Raises:
        RuntimeError: If called without model parameter and no model context exists

    Example:
    ```python
    from zenml import link_artifact_to_model, save_artifact, Model
    from zenml.client import Client

    # Within a step or pipeline with model context
    artifact_version = save_artifact(data, name="preprocessor")
    link_artifact_to_model(artifact_version)  # Uses context model

    # Outside step with explicit model
    client = Client()
    artifact = client.get_artifact_version("preprocessor", version="v1.0")
    model = Model(name="sentiment_classifier", version="1.0.0")
    link_artifact_to_model(artifact, model=model)
    ```
    """
```

Import from:

```python
from zenml import link_artifact_to_model
```

## Usage Examples

### Basic Model Configuration

```python
from zenml import pipeline, step, Model

# Define model configuration
model_config = Model(
    name="fraud_detector",
    version="1.0.0",
    license="MIT",
    description="XGBoost-based fraud detection model",
    tags=["fraud", "xgboost", "production"]
)

@step
def train_model(data: list) -> dict:
    """Train fraud detection model."""
    return {"model": "trained", "accuracy": 0.97}

@pipeline(model=model_config)
def fraud_detection_pipeline():
    """Pipeline with model tracking."""
    data = [1, 2, 3, 4, 5]
    model = train_model(data)
    return model

if __name__ == "__main__":
    fraud_detection_pipeline()
```

### Model with Comprehensive Metadata

```python
from zenml import pipeline, Model

model = Model(
    name="recommendation_engine",
    version="2.1.0",
    license="Apache-2.0",
    description=(
        "Collaborative filtering recommendation engine using "
        "matrix factorization with neural network embeddings"
    ),
    audience="Product teams, ML engineers, data scientists",
    use_cases=(
        "E-commerce product recommendations, content personalization, "
        "user similarity matching"
    ),
    limitations=(
        "Requires minimum 100 interactions per user for accurate recommendations. "
        "Cold start problem for new users/items. English language content only."
    ),
    trade_offs=(
        "Increased model complexity for better accuracy results in higher "
        "inference latency (50ms vs 20ms for simpler model)"
    ),
    ethics=(
        "May reinforce filter bubbles. Recommendations should be diversified. "
        "Privacy considerations for user interaction data."
    ),
    tags=["recommendations", "collaborative-filtering", "neural-network"]
)

@pipeline(model=model)
def recommendation_pipeline():
    """Build recommendation model."""
    pass
```

### Using Model Stages

```python
from zenml import Model
from zenml.enums import ModelStages

# Reference production model
production_model = Model(
    name="text_classifier",
    version=ModelStages.PRODUCTION
)

# Reference staging model
staging_model = Model(
    name="text_classifier",
    version=ModelStages.STAGING
)

# Reference latest model
latest_model = Model(
    name="text_classifier",
    version=ModelStages.LATEST
)
```

### Logging Model Metadata

```python
from zenml import step, pipeline, Model, log_model_metadata

model_config = Model(name="image_classifier", version="3.0.0")

@step
def train_model(data: list) -> dict:
    """Train model."""
    model = {"weights": [0.1, 0.2], "accuracy": 0.94}

    # Log training metadata
    log_model_metadata({
        "training_samples": len(data),
        "training_time": "3600s",
        "optimizer": "adam",
        "learning_rate": 0.001
    })

    return model

@step
def evaluate_model(model: dict, test_data: list) -> dict:
    """Evaluate model."""
    metrics = {
        "accuracy": 0.94,
        "precision": 0.92,
        "recall": 0.95,
        "f1": 0.93
    }

    # Log evaluation metrics
    log_model_metadata({
        "test_accuracy": metrics["accuracy"],
        "test_precision": metrics["precision"],
        "test_recall": metrics["recall"],
        "test_f1": metrics["f1"],
        "test_samples": len(test_data)
    })

    return metrics

@pipeline(model=model_config)
def full_pipeline():
    """Training and evaluation pipeline."""
    data = [1, 2, 3, 4, 5]
    model = train_model(data)
    metrics = evaluate_model(model, [6, 7, 8])
    return metrics
```

### Managing Models with Client

```python
from zenml.client import Client
from zenml.enums import ModelStages

client = Client()

# Create model namespace
model = client.create_model(
    name="customer_churn_predictor",
    license="MIT",
    description="Predicts customer churn probability",
    tags=["churn", "classification"]
)

# Create model version
version = client.create_model_version(
    model_name_or_id=model.id,
    version="1.0.0",
    description="Initial production release",
    tags=["production", "v1"]
)

# Update model version stage
client.update_model_version(
    model_name_or_id=model.id,
    version_name_or_id=version.id,
    stage=ModelStages.PRODUCTION
)

# List all model versions
versions = client.list_model_versions(model_name_or_id=model.id)
for v in versions:
    print(f"Version: {v.version}, Stage: {v.stage}")

# Get model version by stage
prod_version = client.get_model_version(
    model_name_or_id=model.name,
    version=ModelStages.PRODUCTION
)
print(f"Production version: {prod_version.version}")
```

### Linking Artifacts to Models

```python
from zenml import step, pipeline, Model, save_artifact, link_artifact_to_model
from zenml.client import Client

model_config = Model(name="nlp_model", version="1.0.0")

@step
def create_preprocessor() -> dict:
    """Create text preprocessor."""
    return {"tokenizer": "bert", "max_length": 512}

@pipeline(model=model_config)
def training_pipeline():
    """Pipeline that creates related artifacts."""
    preprocessor = create_preprocessor()
    return preprocessor

# Run pipeline
training_pipeline()

# Link external artifact to model
model = Model(name="nlp_model", version="1.0.0")

# Save additional artifact
vocab_artifact = save_artifact(
    data={"vocab": ["hello", "world"], "size": 30000},
    name="vocabulary"
)

# Link to model
link_artifact_to_model(
    artifact_version=vocab_artifact,
    model=model
)

# List model artifacts via client
client = Client()
model_version = client.get_model_version("nlp_model", version="1.0.0")
artifact_links = client.list_model_version_artifact_links(
    model_version_id=model_version.id
)
for link in artifact_links:
    print(f"Linked artifact: {link.artifact_name}")
```

### Model Versioning Strategy

```python
from zenml import pipeline, Model
from datetime import datetime

# Semantic versioning
model_v1 = Model(name="detector", version="1.0.0")
model_v1_1 = Model(name="detector", version="1.1.0")
model_v2 = Model(name="detector", version="2.0.0")

# Date-based versioning
model_dated = Model(
    name="detector",
    version=f"v{datetime.now().strftime('%Y%m%d')}"
)

# Stage-based (for inference pipelines)
model_prod = Model(name="detector", version="production")
model_staging = Model(name="detector", version="staging")

# Hash-based (for reproducibility)
model_hash = Model(
    name="detector",
    version="abc123def"  # Git commit hash or data hash
)
```