# Model Interpretation

Tools for understanding and interpreting model predictions: visualization utilities, loss and error analysis, gradient-based attribution, feature importance, and debugging aids that show how a model behaves and where it fails.

## Capabilities

### Classification Interpretation

Tools for analyzing a trained classifier's predictions: confusion matrices, the most frequently confused class pairs, the highest-loss examples, and a per-class precision, recall, and F1 report.

```python { .api }
class ClassificationInterpretation:
    """
    Interpretation tools for classification models.
    Provides methods to analyze predictions, visualize confusion matrices,
    and identify model strengths and weaknesses.
    """

    @classmethod
    def from_learner(cls, learn, ds_idx=1, dl=None, act=None):
        """
        Create interpretation from trained learner.

        Parameters:
        - learn: Trained Learner instance
        - ds_idx: Dataset index (1 for validation)
        - dl: Custom DataLoader (uses learner's if None)
        - act: Activation function to apply to predictions

        Returns:
        - ClassificationInterpretation instance
        """

    def confusion_matrix(self, slice_size=1):
        """
        Compute confusion matrix for predictions.

        Parameters:
        - slice_size: Size of slice for memory management

        Returns:
        - Confusion matrix as tensor
        """

    def plot_confusion_matrix(self, normalize=False, title='Confusion matrix',
                              cmap="Blues", figsize=None, **kwargs):
        """
        Plot confusion matrix heatmap.

        Parameters:
        - normalize: Normalize confusion matrix
        - title: Plot title
        - cmap: Colormap for heatmap
        - figsize: Figure size
        - **kwargs: Additional plotting arguments
        """

    def most_confused(self, min_val=1):
        """
        Find most confused class pairs.

        Parameters:
        - min_val: Minimum confusion count to include

        Returns:
        - List of (actual, predicted, count) tuples sorted by confusion count
        """

    def plot_top_losses(self, k, largest=True, figsize=(12,12), **kwargs):
        """
        Plot examples with highest losses.

        Parameters:
        - k: Number of examples to show
        - largest: Show largest losses (vs smallest)
        - figsize: Figure size
        - **kwargs: Additional plotting arguments
        """

    def top_losses(self, k=None, largest=True):
        """
        Get examples with highest losses.

        Parameters:
        - k: Number of examples (all if None)
        - largest: Return largest losses (vs smallest)

        Returns:
        - Tuple of (losses, indices)
        """

    def print_classification_report(self):
        """Print detailed classification report with precision, recall, F1."""
```
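
The bookkeeping behind `confusion_matrix`, `most_confused`, and `top_losses` can be sketched in plain Python over integer class indices. This is an illustrative sketch, not the library's implementation: the functions take explicit `targs`, `preds`, `vocab`, and `losses` lists (assumed inputs) instead of reading them from a learner, and they use a rows = actual, columns = predicted convention.

```python
from typing import List, Optional, Sequence, Tuple


def confusion_matrix(targs: Sequence[int], preds: Sequence[int], n_classes: int) -> List[List[int]]:
    """cm[actual][predicted] = number of examples with that (actual, predicted) pair."""
    cm = [[0] * n_classes for _ in range(n_classes)]
    for actual, predicted in zip(targs, preds):
        cm[actual][predicted] += 1
    return cm


def most_confused(cm: List[List[int]], vocab: Sequence[str], min_val: int = 1) -> List[Tuple[str, str, int]]:
    """Off-diagonal (actual, predicted, count) triples with count >= min_val, largest first."""
    n = len(vocab)
    pairs = [
        (vocab[i], vocab[j], cm[i][j])
        for i in range(n)
        for j in range(n)
        if i != j and cm[i][j] >= min_val
    ]
    # sorted() is stable even with reverse=True, so ties keep row-major order
    return sorted(pairs, key=lambda p: p[2], reverse=True)


def top_losses(losses: Sequence[float], k: Optional[int] = None, largest: bool = True) -> Tuple[List[float], List[int]]:
    """Return (losses, indices) ordered by loss; k=None keeps every example."""
    order = sorted(range(len(losses)), key=lambda i: losses[i], reverse=largest)[:k]
    return [losses[i] for i in order], order


vocab = ["cat", "dog", "fox"]
cm = confusion_matrix([0, 0, 1, 1, 2, 2, 2], [0, 1, 1, 0, 2, 1, 1], n_classes=3)
print(cm)                        # [[1, 1, 0], [1, 1, 0], [0, 2, 1]]
print(most_confused(cm, vocab))  # [('fox', 'dog', 2), ('cat', 'dog', 1), ('dog', 'cat', 1)]
print(top_losses([0.1, 2.3, 0.7, 1.5], k=2))  # ([2.3, 1.5], [1, 3])
```

Raising `min_val` filters out rare confusions, which keeps the list focused on systematic mistakes rather than one-off errors.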
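
### Segmentation Interpretation

Interpretation tools for segmentation models, where each pixel is a separate prediction: pixel-wise confusion matrices, per-class accuracy, per-class intersection over union (IoU), and the highest-loss examples.
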
```python { .api }
class SegmentationInterpretation:
    """Interpretation tools for segmentation models."""

    @classmethod
    def from_learner(cls, learn, ds_idx=1, dl=None, act=None):
        """Create segmentation interpretation from learner."""

    def plot_top_losses(self, k, largest=True, figsize=(12,12), **kwargs):
        """Plot segmentation examples with highest losses."""

    def confusion_matrix(self, slice_size=1):
        """Compute pixel-wise confusion matrix."""

    def plot_confusion_matrix(self, normalize=False, **kwargs):
        """Plot segmentation confusion matrix."""

    def per_class_accuracy(self):
        """Calculate accuracy for each segmentation class."""

    def intersection_over_union(self):
        """Calculate IoU for each class."""
```
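
Both per-class metrics treat every pixel as one classification sample. The sketch below is a plain-Python illustration over flattened masks, not the library's code; the function signatures (explicit `target`, `pred`, `n_classes`) are assumptions. Per-class accuracy here is the share of a class's ground-truth pixels predicted as that class, and IoU for class c is the overlap of the predicted and ground-truth masks for c divided by their union.

```python
from typing import List, Sequence


def per_class_accuracy(target: Sequence[int], pred: Sequence[int], n_classes: int) -> List[float]:
    """Share of each class's ground-truth pixels predicted as that class (NaN if absent)."""
    correct = [0] * n_classes
    total = [0] * n_classes
    for t, p in zip(target, pred):
        total[t] += 1
        correct[t] += int(t == p)
    return [c / n if n else float("nan") for c, n in zip(correct, total)]


def intersection_over_union(target: Sequence[int], pred: Sequence[int], n_classes: int) -> List[float]:
    """IoU_c = |target == c AND pred == c| / |target == c OR pred == c| (NaN if the union is empty)."""
    ious = []
    for c in range(n_classes):
        inter = sum(1 for t, p in zip(target, pred) if t == c and p == c)
        union = sum(1 for t, p in zip(target, pred) if t == c or p == c)
        ious.append(inter / union if union else float("nan"))
    return ious


# A 2x3 mask flattened row by row.
target = [0, 0, 1, 1, 2, 2]
pred = [0, 1, 1, 1, 2, 0]
print(per_class_accuracy(target, pred, 3))       # [0.5, 1.0, 0.5]
print(intersection_over_union(target, pred, 3))  # classes 0, 1, 2: 1/3, 2/3, 1/2
```

IoU penalizes false positives as well as misses: class 1 scores perfect accuracy above but only 2/3 IoU, because one class-0 pixel was also labeled 1.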
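
### Base Interpretation Classes

Base classes for building custom interpretation tools. `Interpretation` holds the inputs, raw predictions, targets, decoded predictions, and per-example losses for one DataLoader, and provides loss ranking and plotting on top of them.
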
```python { .api }
class Interpretation:
    """Base class for model interpretation."""

    def __init__(self, dl, inputs, preds, targs, decoded, losses):
        """
        Initialize interpretation.

        Parameters:
        - dl: DataLoader used for predictions
        - inputs: Model inputs
        - preds: Raw predictions
        - targs: Target values
        - decoded: Decoded predictions
        - losses: Loss values for each example
        """

    def top_losses(self, k=None, largest=True):
        """Get examples with highest/lowest losses."""

    def plot_top_losses(self, k, largest=True, **kwargs):
        """Plot examples with extreme losses."""

def plot_top_losses(interp, k, largest=True, **kwargs):
    """Utility function to plot top losses."""
```
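
To show how the constructor's fields relate to each other, here is a minimal, framework-free sketch of a custom interpretation class. `SimpleInterpretation`, `from_logits`, and `describe_top_losses` are hypothetical names; the sketch assumes a classification task, derives `preds` with softmax, `decoded` with argmax, and `losses` with per-example cross-entropy, and omits the `dl` field entirely.

```python
import math
from dataclasses import dataclass
from typing import Any, List, Optional, Sequence, Tuple


@dataclass
class SimpleInterpretation:
    """Hypothetical stand-in for the base class (the `dl` field is omitted)."""
    inputs: Sequence[Any]
    preds: List[List[float]]   # per-class probabilities
    targs: Sequence[int]
    decoded: List[int]         # predicted class index per example
    losses: List[float]        # per-example loss

    @classmethod
    def from_logits(cls, inputs, logits, targs):
        """Derive preds/decoded/losses from raw scores with softmax + cross-entropy."""
        preds, decoded, losses = [], [], []
        for row, t in zip(logits, targs):
            m = max(row)
            exps = [math.exp(v - m) for v in row]  # subtract max for numerical stability
            total = sum(exps)
            probs = [e / total for e in exps]
            preds.append(probs)
            decoded.append(max(range(len(probs)), key=probs.__getitem__))  # argmax (first on ties)
            losses.append(-math.log(probs[t]))  # cross-entropy of the true class
        return cls(inputs, preds, targs, decoded, losses)

    def top_losses(self, k: Optional[int] = None, largest: bool = True) -> Tuple[List[float], List[int]]:
        idx = sorted(range(len(self.losses)), key=self.losses.__getitem__, reverse=largest)[:k]
        return [self.losses[i] for i in idx], idx

    def describe_top_losses(self, k: int, largest: bool = True) -> List[str]:
        """Text stand-in for plot_top_losses: 'input | predicted/actual | loss'."""
        _, idx = self.top_losses(k, largest)
        return [f"{self.inputs[i]} | {self.decoded[i]}/{self.targs[i]} | {self.losses[i]:.2f}" for i in idx]


interp = SimpleInterpretation.from_logits(
    inputs=["a", "b", "c"],
    logits=[[2.0, 0.0], [0.0, 2.0], [0.0, 0.0]],
    targs=[0, 0, 1],
)
print(interp.describe_top_losses(2))  # ['b | 1/0 | 2.13', 'c | 0/1 | 0.69']
```

Keeping per-example losses alongside decoded predictions is what makes "show me the worst mistakes" a simple sort rather than a second pass over the data.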
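
### Gradient-Based Interpretation

Methods that use gradients of the model's output with respect to its input, or with respect to an intermediate layer's activations in the case of GradCAM, to show which parts of an input drive a prediction.
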
```python { .api }
class GradCAM:
    """
    Gradient-weighted Class Activation Mapping.
    Visualizes which parts of input are important for predictions.
    """

    def __init__(self, learn, layer=None):
        """
        Initialize GradCAM.

        Parameters:
        - learn: Trained learner
        - layer: Target layer for activation maps (last conv layer if None)
        """

    def __call__(self, x, class_idx=None):
        """
        Generate GradCAM heatmap.

        Parameters:
        - x: Input image
        - class_idx: Target class index (predicted class if None)

        Returns:
        - Heatmap showing important regions
        """

class IntegratedGradients:
    """
    Integrated Gradients for feature attribution.
    Computes gradients along straight-line path from baseline to input.
    """

    def __init__(self, learn, baseline=None):
        """
        Initialize Integrated Gradients.

        Parameters:
        - learn: Trained learner
        - baseline: Baseline input (zeros if None)
        """

    def attribute(self, x, target=None, n_steps=50):
        """
        Compute integrated gradients attribution.

        Parameters:
        - x: Input to analyze
        - target: Target class (predicted if None)
        - n_steps: Number of integration steps

        Returns:
        - Attribution map
        """

def gradient_times_input(learn, x, target=None):
    """Simple gradient * input attribution method."""

def saliency_map(learn, x, target=None):
    """Generate saliency map from gradients."""
```
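
The arithmetic behind saliency maps, gradient × input, and integrated gradients can be shown on a toy model whose gradient is known in closed form. This sketch is an assumption-laden illustration: `make_logistic` is a hypothetical stand-in for a learner, the functions take a `grad` callable rather than a `learn` object, and GradCAM is left out because it needs convolutional feature maps. Integrated gradients is approximated with a right Riemann sum of `n_steps` points on the straight path from the baseline to the input.

```python
import math
from typing import Callable, List, Optional, Sequence, Tuple

Grad = Callable[[Sequence[float]], List[float]]


def make_logistic(w: Sequence[float], b: float) -> Tuple[Callable[[Sequence[float]], float], Grad]:
    """Toy model f(x) = sigmoid(w . x + b) and its analytic input gradient."""
    def f(x):
        z = sum(wi * xi for wi, xi in zip(w, x)) + b
        return 1.0 / (1.0 + math.exp(-z))

    def grad(x):
        s = f(x)
        return [s * (1.0 - s) * wi for wi in w]  # chain rule: sigmoid'(z) * dz/dx_i

    return f, grad


def saliency_map(grad: Grad, x: Sequence[float]) -> List[float]:
    """Magnitude of the gradient at the input."""
    return [abs(g) for g in grad(x)]


def gradient_times_input(grad: Grad, x: Sequence[float]) -> List[float]:
    """Elementwise product of the input gradient and the input."""
    return [g * xi for g, xi in zip(grad(x), x)]


def integrated_gradients(grad: Grad, x: Sequence[float],
                         baseline: Optional[Sequence[float]] = None, n_steps: int = 50) -> List[float]:
    """(x_i - x'_i) times the average gradient along the path x' -> x (right Riemann sum)."""
    base = [0.0] * len(x) if baseline is None else list(baseline)
    summed = [0.0] * len(x)
    for k in range(1, n_steps + 1):
        alpha = k / n_steps
        point = [bi + alpha * (xi - bi) for xi, bi in zip(x, base)]
        for i, g in enumerate(grad(point)):
            summed[i] += g
    return [(xi - bi) * s / n_steps for xi, bi, s in zip(x, base, summed)]


f, grad = make_logistic(w=[2.0, -1.0, 0.5], b=0.1)
x = [1.0, 2.0, -1.0]
ig = integrated_gradients(grad, x, n_steps=200)
# Completeness: attributions sum (approximately) to f(x) - f(baseline).
print(round(sum(ig), 4), round(f(x) - f([0.0, 0.0, 0.0]), 4))
```

The completeness check is the main reason to prefer integrated gradients over a single gradient: with enough steps, the attributions account for the full change in output between baseline and input.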

### Feature Importance Analysis

Tools for measuring how much each input feature contributes to a tabular model's predictions, mainly through permutation-based importance.

```python { .api }
class FeatureImportance:
    """Analyze feature importance for tabular models."""

    def __init__(self, learn):
        """Initialize with trained tabular learner."""

    def permutation_importance(self, dl=None, n_repeats=5, random_state=None):
        """
        Calculate permutation-based feature importance.

        Parameters:
        - dl: DataLoader (uses validation if None)
        - n_repeats: Number of permutation repeats
        - random_state: Random seed

        Returns:
        - Feature importance scores
        """

    def plot_importance(self, max_vars=20, figsize=(8,6)):
        """Plot feature importance scores."""

def rfpimp_importance(learn, dl=None):
    """Random forest-style permutation importance."""

def oob_score_importance(learn, dl=None):
    """Out-of-bag score-based importance."""
```
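
Permutation importance measures how much a validation metric drops when one column is shuffled, which breaks that column's relationship with the target while keeping its distribution. The sketch below mirrors the `n_repeats` and `random_state` parameters above, but it is not the library's code: it assumes a hypothetical `predict(row)` callable in place of the learner, plain lists in place of the DataLoader, and a `metric(y_true, y_pred)` where higher is better.

```python
import random
from typing import Callable, List, Optional, Sequence


def permutation_importance(
    predict: Callable[[Sequence[float]], int],
    X: Sequence[Sequence[float]],
    y: Sequence[int],
    metric: Callable[[Sequence[int], Sequence[int]], float],
    n_repeats: int = 5,
    random_state: Optional[int] = None,
) -> List[float]:
    """Mean drop in `metric` (higher is better) when one column is shuffled."""
    rng = random.Random(random_state)
    baseline = metric(y, [predict(row) for row in X])
    scores = []
    for j in range(len(X[0])):
        drops = []
        for _ in range(n_repeats):
            column = [row[j] for row in X]
            rng.shuffle(column)  # break the link between feature j and the target
            permuted = []
            for row, v in zip(X, column):
                r = list(row)
                r[j] = v
                permuted.append(r)
            drops.append(baseline - metric(y, [predict(r) for r in permuted]))
        scores.append(sum(drops) / n_repeats)
    return scores


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def model(row: Sequence[float]) -> int:
    """Toy 'model' that only looks at feature 0, so feature 1 should score exactly 0."""
    return int(row[0] > 0.5)


X = [[i / 10, (i * 7 % 10) / 10] for i in range(10)]
y = [int(row[0] > 0.5) for row in X]
print(permutation_importance(model, X, y, accuracy, n_repeats=5, random_state=0))
```

Averaging over several shuffles reduces the noise from any single permutation, and fixing `random_state` makes the scores reproducible.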

### Prediction Analysis

Tools for plotting predictions against ground truth and for analyzing prediction scores, confidence, and error patterns.

```python { .api }
def plot_predictions(learn, ds_idx=1, max_n=9, figsize=None, **kwargs):
    """
    Plot model predictions with ground truth.

    Parameters:
    - learn: Trained learner
    - ds_idx: Dataset index
    - max_n: Maximum number of examples
    - figsize: Figure size
    - **kwargs: Additional plotting arguments
    """

def show_results(learn, ds_idx=1, dl=None, max_n=10, shuffle=True, **kwargs):
    """Show model results on dataset."""

class PredictionAnalyzer:
    """Analyze prediction patterns and model behavior."""

    def __init__(self, learn, dl=None):
        """Initialize analyzer with learner and data."""

    def prediction_distribution(self):
        """Analyze distribution of prediction scores."""

    def confidence_analysis(self):
        """Analyze prediction confidence patterns."""

    def error_analysis(self):
        """Analyze patterns in model errors."""
```
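
One common form of confidence analysis groups predictions by their top probability and checks accuracy within each group; one common form of error analysis lists mistakes made with high confidence. The sketch below illustrates both under stated assumptions: `confidence_analysis` and `confident_errors` are hypothetical free functions over lists of class probabilities, not the methods of `PredictionAnalyzer`, and the bin count and threshold are arbitrary defaults.

```python
from typing import List, Sequence, Tuple


def confidence_analysis(probs: Sequence[Sequence[float]], targs: Sequence[int],
                        n_bins: int = 5) -> List[Tuple[float, float, int, float]]:
    """Bucket predictions by confidence (max probability); return (low, high, count, accuracy) per non-empty bin."""
    counts = [0] * n_bins
    correct = [0] * n_bins
    for p, t in zip(probs, targs):
        conf = max(p)
        pred = list(p).index(conf)
        b = min(int(conf * n_bins), n_bins - 1)  # conf == 1.0 falls into the last bin
        counts[b] += 1
        correct[b] += int(pred == t)
    return [
        (b / n_bins, (b + 1) / n_bins, counts[b], correct[b] / counts[b])
        for b in range(n_bins)
        if counts[b]
    ]


def confident_errors(probs: Sequence[Sequence[float]], targs: Sequence[int],
                     threshold: float = 0.9) -> List[int]:
    """Indices of wrong predictions made with confidence >= threshold."""
    out = []
    for i, (p, t) in enumerate(zip(probs, targs)):
        conf = max(p)
        if conf >= threshold and list(p).index(conf) != t:
            out.append(i)
    return out


probs = [[0.95, 0.05], [0.9, 0.1], [0.65, 0.35], [0.45, 0.55], [0.15, 0.85], [0.97, 0.03]]
targs = [0, 1, 0, 1, 1, 1]
for low, high, n, acc in confidence_analysis(probs, targs):
    print(f"{low:.1f}-{high:.1f}: n={n}, accuracy={acc:.2f}")
print(confident_errors(probs, targs))  # [1, 5]
```

If accuracy in the high-confidence bins is well below their confidence, the model is overconfident, and the indices from `confident_errors` are good candidates for checking label noise.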

### Visualization Utilities

Plotting helpers for loss curves, learning-rate finder results, tracked metrics, batch predictions, and per-layer activation statistics.

```python { .api }
def plot_multi_losses(losses_list, labels=None, figsize=(12,8)):
    """Plot multiple loss curves for comparison."""

def plot_lr_find(learn, skip_start=5, skip_end=5, suggestion=True):
    """Plot learning rate finder results."""

def plot_metrics(learn, nrows=None, ncols=None, figsize=None):
    """Plot all tracked metrics."""

def show_batch_predictions(learn, dl=None, max_n=9, figsize=None, **kwargs):
    """Show batch with predictions overlaid."""

class ActivationStats:
    """Analyze activation statistics across model layers."""

    def __init__(self, learn):
        """Initialize with learner."""

    def stats_by_layer(self):
        """Get activation statistics for each layer."""

    def plot_layer_stats(self, figsize=(15,5)):
        """Plot activation statistics."""

def dead_chart(activs, figsize=(10,5)):
    """Chart showing dead neurons by layer."""

def hist_chart(activs, figsize=(10,5)):
    """Histogram of activations by layer."""
```
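
The statistics these charts display can be computed directly once activations have been collected. This sketch assumes a hypothetical input format, a dict mapping layer names to flat lists of recorded activations (for example, gathered with forward hooks), and computes the mean, population standard deviation, fraction of near-zero values, and equal-width histogram counts that a chart like `hist_chart` would plot. It is an illustration, not the library's implementation.

```python
import statistics
from typing import Dict, List, Optional, Sequence


def stats_by_layer(activs: Dict[str, Sequence[float]], eps: float = 1e-8) -> Dict[str, Dict[str, float]]:
    """Mean, population std, and fraction of near-zero activations for each layer."""
    return {
        name: {
            "mean": statistics.fmean(a),
            "std": statistics.pstdev(a),
            "frac_near_zero": sum(abs(v) <= eps for v in a) / len(a),
        }
        for name, a in activs.items()
    }


def histogram(values: Sequence[float], n_bins: int = 4,
              lo: Optional[float] = None, hi: Optional[float] = None) -> List[int]:
    """Counts in n_bins equal-width bins over [lo, hi]; out-of-range values are clamped."""
    lo = min(values) if lo is None else lo
    hi = max(values) if hi is None else hi
    width = (hi - lo) / n_bins or 1.0  # avoid division by zero when all values are equal
    counts = [0] * n_bins
    for v in values:
        i = int((v - lo) / width)
        counts[min(max(i, 0), n_bins - 1)] += 1
    return counts


activs = {"relu1": [0.0, 0.0, 1.0, 3.0], "relu2": [0.5, 1.5, 2.5, 3.5]}
for name, s in stats_by_layer(activs).items():
    print(name, {k: round(v, 3) for k, v in s.items()})
print(histogram(activs["relu2"], n_bins=4, lo=0.0, hi=4.0))  # [1, 1, 1, 1]
```

A layer whose near-zero fraction keeps rising during training is a typical warning sign for ReLU networks.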

### Model Debugging

Tools for diagnosing architecture and training problems: gradient flow, dead neurons, weight distributions, per-layer outputs, model summaries, and memory usage by layer.

```python { .api }
class ModelDebugger:
    """Debug model architecture and training issues."""

    def __init__(self, learn):
        """Initialize debugger with learner."""

    def check_gradient_flow(self):
        """Check for gradient flow issues."""

    def analyze_layer_outputs(self, x):
        """Analyze outputs from each layer."""

    def detect_dead_neurons(self):
        """Detect neurons that never activate."""

    def weight_distribution_analysis(self):
        """Analyze weight distributions across layers."""

def summary(learn, input_size=None):
    """Print model summary with layer details."""

def model_sizes(learn):
    """Analyze model memory usage by layer."""

def check_model(learn, lr=1e-3):
    """Run model health checks."""
```
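
Two of these checks reduce to simple rules once per-layer numbers are available. The sketch below assumes gradient norms and activations have already been collected per layer (the dict formats are hypothetical), flags layers whose gradient L2 norm falls outside a range, and reports units that never exceed zero on any sample. The `low` and `high` thresholds are illustrative assumptions, not values from any library.

```python
import math
from typing import Dict, List, Sequence


def l2_norm(values: Sequence[float]) -> float:
    return math.sqrt(sum(v * v for v in values))


def check_gradient_flow(grad_norms: Dict[str, float], low: float = 1e-6, high: float = 1e3) -> Dict[str, str]:
    """Flag layers whose gradient norm looks vanishing (< low) or exploding (> high)."""
    issues = {}
    for name, g in grad_norms.items():
        if g < low:
            issues[name] = "vanishing"
        elif g > high:
            issues[name] = "exploding"
    return issues


def detect_dead_neurons(layer_activs: Dict[str, Sequence[Sequence[float]]], eps: float = 0.0) -> Dict[str, List[int]]:
    """Units whose activation never exceeds eps over all samples (rows = samples, columns = units)."""
    return {
        name: [u for u in range(len(rows[0])) if all(row[u] <= eps for row in rows)]
        for name, rows in layer_activs.items()
    }


grads = {"conv1": l2_norm([1e-8, -2e-8]), "conv2": l2_norm([0.3, 0.4]), "fc": 5e3}
print(check_gradient_flow(grads))  # {'conv1': 'vanishing', 'fc': 'exploding'}
acts = {"fc1": [[0.0, 1.2, 0.0], [0.0, 0.3, 0.5]], "fc2": [[0.1, 0.0], [0.2, 0.4]]}
print(detect_dead_neurons(acts))  # {'fc1': [0], 'fc2': []}
```

A unit counts as dead only if it is inactive on every sample; one that is zero on some inputs (like unit 2 of `fc1`) is just sparse.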

### Interactive Interpretation

Notebook widgets for interactively browsing classification and segmentation predictions.

```python { .api }
class InteractiveClassifier:
    """Interactive widget for exploring classification predictions."""

    def __init__(self, learn, ds_idx=1):
        """Initialize interactive classifier."""

    def show(self):
        """Display interactive widget."""

class InteractiveSegmentation:
    """Interactive widget for exploring segmentation predictions."""

    def __init__(self, learn, ds_idx=1):
        """Initialize interactive segmentation explorer."""

    def show(self):
        """Display interactive widget."""

def create_interpretation_widget(learn, interpretation_type='classification'):
    """Create appropriate interpretation widget for model type."""
```
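
Widgets are hard to show in static text, but the state they manage is simple. The sketch below models the browsing logic of an explorer like `InteractiveClassifier` without any UI: a hypothetical `PredictionBrowser` that steps through examples, optionally only the misclassified ones. In a notebook, `next` and `prev` could be bound to buttons (for example, an ipywidgets `Button` via its `on_click` callback); that wiring is an assumption and is not shown.

```python
from typing import List, Sequence


class PredictionBrowser:
    """UI-free model of an interactive explorer: the state Previous/Next buttons would drive."""

    def __init__(self, items: Sequence[str], decoded: Sequence[str], targs: Sequence[str],
                 only_errors: bool = False):
        self.items, self.decoded, self.targs = items, decoded, targs
        # Optionally restrict browsing to misclassified examples.
        self.order: List[int] = [
            i for i in range(len(items)) if not only_errors or decoded[i] != targs[i]
        ]
        self.pos = 0

    def current(self) -> str:
        if not self.order:
            return "no examples to show"
        i = self.order[self.pos]
        return f"{self.items[i]}: predicted {self.decoded[i]}, actual {self.targs[i]}"

    def next(self) -> str:
        # Clamp at the last example instead of wrapping around.
        self.pos = min(self.pos + 1, max(len(self.order) - 1, 0))
        return self.current()

    def prev(self) -> str:
        self.pos = max(self.pos - 1, 0)
        return self.current()


browser = PredictionBrowser(
    items=["img0", "img1", "img2"],
    decoded=["cat", "dog", "dog"],
    targs=["cat", "cat", "dog"],
    only_errors=True,
)
print(browser.current())  # img1: predicted dog, actual cat
```

Keeping the navigation state separate from the display code makes the logic testable without a notebook front end.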