# Model Interpretation

Tools for understanding and interpreting model predictions, including visualization utilities, analysis methods, and techniques for gaining insight into model behavior and decision-making.

## Capabilities

### Classification Interpretation

Comprehensive analysis tools for understanding classification model predictions and performance.

```python { .api }
class ClassificationInterpretation:
    """
    Interpretation tools for classification models.

    Provides methods to analyze predictions, visualize confusion matrices,
    and identify model strengths and weaknesses.
    """

    @classmethod
    def from_learner(cls, learn, ds_idx=1, dl=None, act=None):
        """
        Create interpretation from a trained learner.

        Parameters:
        - learn: Trained Learner instance
        - ds_idx: Dataset index (1 for validation)
        - dl: Custom DataLoader (uses the learner's if None)
        - act: Activation function to apply to predictions

        Returns:
        - ClassificationInterpretation instance
        """

    def confusion_matrix(self, slice_size=1):
        """
        Compute confusion matrix for predictions.

        Parameters:
        - slice_size: Size of slice for memory management

        Returns:
        - Confusion matrix as tensor
        """

    def plot_confusion_matrix(self, normalize=False, title='Confusion matrix',
                              cmap="Blues", figsize=None, **kwargs):
        """
        Plot confusion matrix heatmap.

        Parameters:
        - normalize: Normalize confusion matrix
        - title: Plot title
        - cmap: Colormap for heatmap
        - figsize: Figure size
        - **kwargs: Additional plotting arguments
        """

    def most_confused(self, min_val=1):
        """
        Find most confused class pairs.

        Parameters:
        - min_val: Minimum confusion count to include

        Returns:
        - List of (actual, predicted, count) tuples sorted by confusion count
        """

    def plot_top_losses(self, k, largest=True, figsize=(12,12), **kwargs):
        """
        Plot examples with highest losses.

        Parameters:
        - k: Number of examples to show
        - largest: Show largest losses (vs smallest)
        - figsize: Figure size
        - **kwargs: Additional plotting arguments
        """

    def top_losses(self, k=None, largest=True):
        """
        Get examples with highest losses.

        Parameters:
        - k: Number of examples (all if None)
        - largest: Return largest losses (vs smallest)

        Returns:
        - Tuple of (losses, indices)
        """

    def print_classification_report(self):
        """Print detailed classification report with precision, recall, F1."""
```
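
A typical review loop, as a minimal sketch: build the interpretation from an already-trained classification `Learner` (assumed here to exist as `learn`), then inspect the confusion matrix and the highest-loss examples using only the methods documented above.

```python
# Minimal sketch: `learn` is assumed to be an already-trained classification Learner.
interp = ClassificationInterpretation.from_learner(learn)  # defaults to the validation set (ds_idx=1)

interp.plot_confusion_matrix(normalize=True, figsize=(6, 6))  # normalized heatmap
print(interp.most_confused(min_val=3))     # (actual, predicted, count) pairs with count >= 3

losses, idxs = interp.top_losses(k=9)      # the 9 highest losses and their indices
interp.plot_top_losses(9, largest=True)    # visualize them for manual review
interp.print_classification_report()       # precision, recall, F1 per class
```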

### Segmentation Interpretation

Specialized interpretation tools for segmentation models and pixel-level predictions.

```python { .api }
class SegmentationInterpretation:
    """Interpretation tools for segmentation models."""

    @classmethod
    def from_learner(cls, learn, ds_idx=1, dl=None, act=None):
        """Create segmentation interpretation from learner."""

    def plot_top_losses(self, k, largest=True, figsize=(12,12), **kwargs):
        """Plot segmentation examples with highest losses."""

    def confusion_matrix(self, slice_size=1):
        """Compute pixel-wise confusion matrix."""

    def plot_confusion_matrix(self, normalize=False, **kwargs):
        """Plot segmentation confusion matrix."""

    def per_class_accuracy(self):
        """Calculate accuracy for each segmentation class."""

    def intersection_over_union(self):
        """Calculate IoU for each class."""
```
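
The same workflow carries over to pixel-level models. A minimal sketch, assuming `learn` is a trained segmentation Learner and using only the methods listed above:

```python
# Sketch: `learn` is assumed to be a trained segmentation Learner.
seg_interp = SegmentationInterpretation.from_learner(learn)

seg_interp.plot_top_losses(4)                 # worst-predicted images with their masks
cm = seg_interp.confusion_matrix()            # pixel-wise confusion matrix
print(seg_interp.per_class_accuracy())        # accuracy per segmentation class
print(seg_interp.intersection_over_union())   # IoU per class
```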

### Base Interpretation Classes

Foundation classes for building custom interpretation tools.

```python { .api }
class Interpretation:
    """Base class for model interpretation."""

    def __init__(self, dl, inputs, preds, targs, decoded, losses):
        """
        Initialize interpretation.

        Parameters:
        - dl: DataLoader used for predictions
        - inputs: Model inputs
        - preds: Raw predictions
        - targs: Target values
        - decoded: Decoded predictions
        - losses: Loss values for each example
        """

    def top_losses(self, k=None, largest=True):
        """Get examples with highest/lowest losses."""

    def plot_top_losses(self, k, largest=True, **kwargs):
        """Plot examples with extreme losses."""

def plot_top_losses(interp, k, largest=True, **kwargs):
    """Utility function to plot top losses."""
```
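
Because the base class stores per-example losses, custom interpretations mostly add domain-specific views on top of `top_losses`. The sketch below is illustrative only: `RegressionInterpretation` and `worst_residuals` are hypothetical names introduced here, not part of the documented API.

```python
# Hypothetical subclass built on the Interpretation base class documented above.
class RegressionInterpretation(Interpretation):
    """Adds a residual-oriented view on top of the generic loss-based tools."""

    def worst_residuals(self, k=10):
        # top_losses is inherited from Interpretation and returns (losses, indices);
        # for a regression loss the highest-loss rows are the largest residuals.
        losses, idxs = self.top_losses(k=k, largest=True)
        return list(zip(idxs, losses))
```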

### Gradient-Based Interpretation

Methods using gradients to understand model decisions and feature importance.

```python { .api }
class GradCAM:
    """
    Gradient-weighted Class Activation Mapping.

    Visualizes which parts of the input are important for predictions.
    """

    def __init__(self, learn, layer=None):
        """
        Initialize GradCAM.

        Parameters:
        - learn: Trained learner
        - layer: Target layer for activation maps (last conv layer if None)
        """

    def __call__(self, x, class_idx=None):
        """
        Generate GradCAM heatmap.

        Parameters:
        - x: Input image
        - class_idx: Target class index (predicted class if None)

        Returns:
        - Heatmap showing important regions
        """

class IntegratedGradients:
    """
    Integrated Gradients for feature attribution.

    Computes gradients along a straight-line path from baseline to input.
    """

    def __init__(self, learn, baseline=None):
        """
        Initialize Integrated Gradients.

        Parameters:
        - learn: Trained learner
        - baseline: Baseline input (zeros if None)
        """

    def attribute(self, x, target=None, n_steps=50):
        """
        Compute integrated gradients attribution.

        Parameters:
        - x: Input to analyze
        - target: Target class (predicted if None)
        - n_steps: Number of integration steps

        Returns:
        - Attribution map
        """

def gradient_times_input(learn, x, target=None):
    """Simple gradient * input attribution method."""

def saliency_map(learn, x, target=None):
    """Generate saliency map from gradients."""
```
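
All of the gradient-based tools follow the same call pattern: construct with a learner, then call on a single input. A minimal sketch, assuming `learn` is a trained image classifier and `x` is one preprocessed input image (both names assumed here):

```python
# Sketch: `learn` is a trained image classifier, `x` a single preprocessed input image.
cam = GradCAM(learn)                  # defaults to the last convolutional layer
heatmap = cam(x)                      # heatmap for the predicted class
heatmap_cat3 = cam(x, class_idx=3)    # heatmap for a specific class index

ig = IntegratedGradients(learn)       # zero baseline by default
attr = ig.attribute(x, n_steps=100)   # finer path integration than the default 50 steps

sal = saliency_map(learn, x)          # plain gradient saliency for comparison
```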

### Feature Importance Analysis

Tools for analyzing feature importance in different types of models.

```python { .api }
class FeatureImportance:
    """Analyze feature importance for tabular models."""

    def __init__(self, learn):
        """Initialize with trained tabular learner."""

    def permutation_importance(self, dl=None, n_repeats=5, random_state=None):
        """
        Calculate permutation-based feature importance.

        Parameters:
        - dl: DataLoader (uses validation if None)
        - n_repeats: Number of permutation repeats
        - random_state: Random seed

        Returns:
        - Feature importance scores
        """

    def plot_importance(self, max_vars=20, figsize=(8,6)):
        """Plot feature importance scores."""

def rfpimp_importance(learn, dl=None):
    """Random forest-style permutation importance."""

def oob_score_importance(learn, dl=None):
    """Out-of-bag score-based importance."""
```
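
Permutation importance shuffles one feature at a time and measures how much performance degrades, so larger scores indicate features the model relies on more heavily. A minimal sketch, assuming `learn` is a trained tabular learner:

```python
# Sketch: `learn` is assumed to be a trained tabular Learner.
fi = FeatureImportance(learn)
scores = fi.permutation_importance(n_repeats=10, random_state=42)  # average over 10 shuffles
fi.plot_importance(max_vars=15)  # bar chart of the 15 most influential features
```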

### Prediction Analysis

Tools for analyzing and visualizing model predictions across different domains.

```python { .api }
def plot_predictions(learn, ds_idx=1, max_n=9, figsize=None, **kwargs):
    """
    Plot model predictions with ground truth.

    Parameters:
    - learn: Trained learner
    - ds_idx: Dataset index
    - max_n: Maximum number of examples
    - figsize: Figure size
    - **kwargs: Additional plotting arguments
    """

def show_results(learn, ds_idx=1, dl=None, max_n=10, shuffle=True, **kwargs):
    """Show model results on dataset."""

class PredictionAnalyzer:
    """Analyze prediction patterns and model behavior."""

    def __init__(self, learn, dl=None):
        """Initialize analyzer with learner and data."""

    def prediction_distribution(self):
        """Analyze distribution of prediction scores."""

    def confidence_analysis(self):
        """Analyze prediction confidence patterns."""

    def error_analysis(self):
        """Analyze patterns in model errors."""
```
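
A typical review session combines the plotting helpers with the analyzer. A minimal sketch, assuming `learn` is any trained learner and using only the calls documented above:

```python
# Sketch: `learn` is any trained Learner.
show_results(learn, max_n=8)           # qualitative spot-check of predictions
plot_predictions(learn, max_n=9)       # predictions shown alongside ground truth

analyzer = PredictionAnalyzer(learn)   # dl left as the default (None)
analyzer.prediction_distribution()     # how prediction scores are spread
analyzer.confidence_analysis()         # where the model is over- or under-confident
analyzer.error_analysis()              # recurring patterns in the mistakes
```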

### Visualization Utilities

Utility functions for creating informative visualizations of model behavior.

```python { .api }
def plot_multi_losses(losses_list, labels=None, figsize=(12,8)):
    """Plot multiple loss curves for comparison."""

def plot_lr_find(learn, skip_start=5, skip_end=5, suggestion=True):
    """Plot learning rate finder results."""

def plot_metrics(learn, nrows=None, ncols=None, figsize=None):
    """Plot all tracked metrics."""

def show_batch_predictions(learn, dl=None, max_n=9, figsize=None, **kwargs):
    """Show batch with predictions overlaid."""

class ActivationStats:
    """Analyze activation statistics across model layers."""

    def __init__(self, learn):
        """Initialize with learner."""

    def stats_by_layer(self):
        """Get activation statistics for each layer."""

    def plot_layer_stats(self, figsize=(15,5)):
        """Plot activation statistics."""

def dead_chart(activs, figsize=(10,5)):
    """Chart showing dead neurons by layer."""

def hist_chart(activs, figsize=(10,5)):
    """Histogram of activations by layer."""
```
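
Activation statistics are most useful right after training, to spot layers whose activations collapse or explode. A minimal sketch, assuming a trained `learn`:

```python
# Sketch: `learn` is a trained Learner.
stats = ActivationStats(learn)
per_layer = stats.stats_by_layer()   # activation statistics for each layer
stats.plot_layer_stats()             # visual overview across the network

plot_lr_find(learn)                  # learning-rate finder curve with a suggestion
plot_metrics(learn)                  # all tracked metrics over training
```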

### Model Debugging

Tools for debugging model architecture and training issues.

```python { .api }
class ModelDebugger:
    """Debug model architecture and training issues."""

    def __init__(self, learn):
        """Initialize debugger with learner."""

    def check_gradient_flow(self):
        """Check for gradient flow issues."""

    def analyze_layer_outputs(self, x):
        """Analyze outputs from each layer."""

    def detect_dead_neurons(self):
        """Detect neurons that never activate."""

    def weight_distribution_analysis(self):
        """Analyze weight distributions across layers."""

def summary(learn, input_size=None):
    """Print model summary with layer details."""

def model_sizes(learn):
    """Analyze model memory usage by layer."""

def check_model(learn, lr=1e-3):
    """Run model health checks."""
```
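
Debugging usually starts with a structural summary and then drills into gradients and activations. A minimal sketch, assuming `learn` wraps the model under investigation and `xb` is one input batch (a name introduced here for illustration):

```python
# Sketch: `learn` wraps the model under investigation; `xb` is one input batch.
summary(learn)                       # layer-by-layer overview of the architecture
model_sizes(learn)                   # memory usage per layer
check_model(learn, lr=1e-3)          # quick health checks before a long training run

dbg = ModelDebugger(learn)
dbg.check_gradient_flow()            # look for vanishing / exploding gradients
dbg.detect_dead_neurons()            # units that never activate
dbg.weight_distribution_analysis()   # weight distributions across layers
dbg.analyze_layer_outputs(xb)        # inspect intermediate outputs for one batch
```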

### Interactive Interpretation

Tools for interactive exploration of model predictions and behavior.

```python { .api }
class InteractiveClassifier:
    """Interactive widget for exploring classification predictions."""

    def __init__(self, learn, ds_idx=1):
        """Initialize interactive classifier."""

    def show(self):
        """Display interactive widget."""

class InteractiveSegmentation:
    """Interactive widget for exploring segmentation predictions."""

    def __init__(self, learn, ds_idx=1):
        """Initialize interactive segmentation explorer."""

    def show(self):
        """Display interactive widget."""

def create_interpretation_widget(learn, interpretation_type='classification'):
    """Create appropriate interpretation widget for model type."""
```
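
In a notebook, the widgets give a point-and-click view of the same information. A minimal sketch, assuming a notebook environment with widget support:

```python
# Sketch: intended for a notebook environment with widget support.
explorer = InteractiveClassifier(learn)   # or InteractiveSegmentation for masks
explorer.show()                           # display the widget inline

# Or let the factory pick the right widget for the model type:
widget = create_interpretation_widget(learn, interpretation_type='classification')
```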