# Classification Analysis

Visualizers for evaluating classification model performance, providing insight into prediction accuracy, error patterns, class distributions, and threshold selection. Most of them support both binary and multi-class problems; `DiscriminationThreshold` is limited to binary classification.

## Capabilities

### ROC/AUC Analysis

ROC (Receiver Operating Characteristic) curves and AUC (Area Under the Curve) scores for binary and multi-class classifiers. Each curve plots the true positive rate against the false positive rate as the classification threshold varies. A curve that bows toward the top-left corner, and an AUC closer to 1.0, indicates better separation between classes. For multi-class problems, a one-vs-rest curve is drawn per class, optionally summarized by micro- and macro-averaged curves.

```python { .api }
class ROCAUC(ClassificationScoreVisualizer):
    """
    ROC/AUC visualizer for classification models.

    Parameters:
    - estimator: scikit-learn classifier
    - ax: matplotlib axes object, axes to plot on
    - micro: bool, whether to plot the micro-averaged ROC curve for multi-class problems (default: True)
    - macro: bool, whether to plot the macro-averaged ROC curve for multi-class problems (default: True)
    - per_class: bool, whether to plot a one-vs-rest ROC curve for each class (default: True)
    - binary: bool, whether to draw a single ROC curve for a binary problem; shortcut for
      micro=False, macro=False, per_class=False (default: False)
    - classes: list of class labels for display
    - encoder: label encoder for transforming class labels
    - is_fitted: str or bool, whether the estimator is already fitted ("auto", True, False)
    - force_model: bool, whether to skip the check that the estimator is a classifier (default: False)
    """
    def __init__(self, estimator, ax=None, micro=True, macro=True, per_class=True,
                 binary=False, classes=None, encoder=None, is_fitted="auto",
                 force_model=False, **kwargs): ...
    def fit(self, X, y, **kwargs): ...
    def score(self, X, y, **kwargs): ...
    def show(self, **kwargs): ...

def roc_auc(estimator, X_train, y_train, X_test=None, y_test=None, classes=None, **kwargs):
    """
    Functional API for ROC/AUC visualization.

    Parameters:
    - estimator: scikit-learn classifier
    - X_train: training features
    - y_train: training labels
    - X_test: test features (optional; if omitted, the training data is scored)
    - y_test: test labels (optional)
    - classes: list of class labels

    Returns:
    ROCAUC visualizer instance
    """
```
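
**Usage Example:**

```python
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from yellowbrick.classifier import ROCAUC, roc_auc

# Binary dataset: target 0 = malignant, 1 = benign
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
classes = ["Malignant", "Benign"]  # display names in label order (0, 1)

# Class-based API: fit on training data, score on held-out data, render
model = RandomForestClassifier(random_state=42)
visualizer = ROCAUC(model, classes=classes)
visualizer.fit(X_train, y_train)
visualizer.score(X_test, y_test)
visualizer.show()

# Functional API: fits, scores, and shows in a single call
roc_auc(model, X_train, y_train, X_test, y_test, classes=classes)
```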
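
To make the threshold sweep concrete, here is a minimal, dependency-free sketch of how ROC points and the AUC can be computed for a binary problem. It is not yellowbrick's code (the visualizer relies on scikit-learn's metrics for this); the helper names `roc_points` and `trapezoid_auc` are hypothetical, and the sketch assumes 0/1 labels where a higher score means "more likely positive".

```python
def roc_points(y_true, scores):
    """Return (fpr, tpr) lists by sweeping the threshold over every distinct score."""
    n_pos = sum(1 for y in y_true if y == 1)
    n_neg = len(y_true) - n_pos
    fpr, tpr = [0.0], [0.0]  # start at the origin (nothing predicted positive)
    for t in sorted(set(scores), reverse=True):
        # Predict positive whenever the score reaches the threshold
        tp = sum(1 for y, s in zip(y_true, scores) if s >= t and y == 1)
        fp = sum(1 for y, s in zip(y_true, scores) if s >= t and y == 0)
        tpr.append(tp / n_pos)
        fpr.append(fp / n_neg)
    return fpr, tpr


def trapezoid_auc(xs, ys):
    """Area under a piecewise-linear curve using the trapezoidal rule."""
    return sum((x1 - x0) * (y0 + y1) / 2 for x0, x1, y0, y1 in zip(xs, xs[1:], ys, ys[1:]))


y_true = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
fpr, tpr = roc_points(y_true, scores)
print(fpr, tpr)                  # [0.0, 0.0, 0.5, 0.5, 1.0] [0.0, 0.5, 0.5, 1.0, 1.0]
print(trapezoid_auc(fpr, tpr))   # 0.75
```

Lowering the threshold can only add predicted positives, so both rates rise together; the AUC summarizes how much faster the true positive rate climbs.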

### Confusion Matrix

Confusion matrix showing how predictions are distributed across classes: each row is an actual class and each column a predicted class, so correct predictions lie on the diagonal and off-diagonal cells reveal which classes are confused with each other. Cells display raw counts or, with `percent=True`, percentages, and the colormap is configurable.

```python { .api }
class ConfusionMatrix(ClassificationScoreVisualizer):
    """
    Confusion matrix visualizer for classification models.

    Parameters:
    - estimator: scikit-learn classifier
    - ax: matplotlib axes object, axes to plot on
    - sample_weight: array-like of sample weights
    - percent: bool, whether to display percentages instead of counts (default: False)
    - classes: list of class labels for display
    - encoder: label encoder for transforming class labels
    - cmap: str, matplotlib colormap name (default: "YlOrRd")
    - fontsize: int, font size for the cell annotations
    - is_fitted: str or bool, whether the estimator is already fitted ("auto", True, False)
    - force_model: bool, whether to skip the check that the estimator is a classifier (default: False)
    """
    def __init__(self, estimator, ax=None, sample_weight=None, percent=False,
                 classes=None, encoder=None, cmap="YlOrRd", fontsize=None,
                 is_fitted="auto", force_model=False, **kwargs): ...
    def fit(self, X, y, **kwargs): ...
    def score(self, X, y, **kwargs): ...
    def show(self, **kwargs): ...

def confusion_matrix(estimator, X_train, y_train, X_test=None, y_test=None, classes=None, **kwargs):
    """
    Functional API for confusion matrix visualization.

    Parameters:
    - estimator: scikit-learn classifier
    - X_train: training features
    - y_train: training labels
    - X_test: test features (optional; if omitted, the training data is scored)
    - y_test: test labels (optional)
    - classes: list of class labels

    Returns:
    ConfusionMatrix visualizer instance
    """
```
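
The layout can be reproduced in a few lines of plain Python. This sketch builds the count matrix (rows are actual classes, columns are predicted classes) and a row-normalized percentage version. Treating `percent=True` as normalization by each actual class's total is an assumption here, and `confusion_counts` and `row_percentages` are hypothetical helpers, not yellowbrick functions.

```python
def confusion_counts(y_true, y_pred, labels):
    """Count matrix: rows are actual classes, columns are predicted classes."""
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for actual, predicted in zip(y_true, y_pred):
        matrix[index[actual]][index[predicted]] += 1
    return matrix


def row_percentages(matrix):
    """Express each cell as a percentage of its row (actual class) total."""
    return [
        [100.0 * cell / sum(row) if sum(row) else 0.0 for cell in row]
        for row in matrix
    ]


labels = ["cat", "dog", "bird"]
y_true = ["cat", "cat", "dog", "dog", "dog", "bird"]
y_pred = ["cat", "dog", "dog", "dog", "cat", "bird"]

cm = confusion_counts(y_true, y_pred, labels)
print(cm)                   # [[1, 1, 0], [1, 2, 0], [0, 0, 1]]
print(row_percentages(cm))  # second row is roughly [33.3, 66.7, 0.0]
```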

### Classification Report

Heatmap of per-class precision, recall, F1-score, and support, making it easy to see which classes the model handles well and which it struggles with.

```python { .api }
class ClassificationReport(ClassificationScoreVisualizer):
    """
    Classification report heatmap visualizer.

    Parameters:
    - estimator: scikit-learn classifier
    - classes: list of class labels for display
    - sample_weight: array-like of sample weights
    - support: bool, whether to draw the support column
    - cmap: matplotlib colormap for the heatmap
    """
    def __init__(self, estimator, classes=None, sample_weight=None, support=True,
                 cmap='RdYlBu_r', **kwargs): ...
    def fit(self, X, y, **kwargs): ...
    def score(self, X, y, **kwargs): ...
    def show(self, **kwargs): ...

def classification_report(estimator, X_train, y_train, X_test=None, y_test=None, classes=None, **kwargs):
    """
    Functional API for classification report visualization.

    Parameters:
    - estimator: scikit-learn classifier
    - X_train: training features
    - y_train: training labels
    - X_test: test features (optional; if omitted, the training data is scored)
    - y_test: test labels (optional)
    - classes: list of class labels

    Returns:
    ClassificationReport visualizer instance
    """
```
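
Each cell of the heatmap is a standard per-class metric computed one-vs-rest. A minimal sketch of how those numbers follow from true and predicted labels; the function name is hypothetical, and reporting undefined ratios (division by zero) as 0.0 is one common convention rather than necessarily what yellowbrick does.

```python
def classification_metrics(y_true, y_pred, labels):
    """Return {label: (precision, recall, f1, support)} for each label."""
    report = {}
    for label in labels:
        pairs = list(zip(y_true, y_pred))
        tp = sum(1 for t, p in pairs if t == label and p == label)
        fp = sum(1 for t, p in pairs if t != label and p == label)
        fn = sum(1 for t, p in pairs if t == label and p != label)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        # Support is the number of instances that actually belong to the class
        report[label] = (precision, recall, f1, tp + fn)
    return report


labels = ["cat", "dog", "bird"]
y_true = ["cat", "cat", "dog", "dog", "dog", "bird"]
y_pred = ["cat", "dog", "dog", "dog", "cat", "bird"]
report = classification_metrics(y_true, y_pred, labels)
for label, values in report.items():
    print(label, values)
```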

### Class Prediction Error

Stacked bar chart with one bar per actual class, each bar segmented by the classes the model predicted for those instances. It conveys the same information as a confusion matrix in a form that makes systematic prediction biases (one class absorbing predictions that belong to another) and class imbalance easy to spot.

```python { .api }
class ClassPredictionError(ClassificationScoreVisualizer):
    """
    Class prediction error visualizer.

    Parameters:
    - estimator: scikit-learn classifier
    - classes: list of class labels for display
    """
    def __init__(self, estimator, classes=None, **kwargs): ...
    def fit(self, X, y, **kwargs): ...
    def score(self, X, y, **kwargs): ...
    def show(self, **kwargs): ...

def class_prediction_error(estimator, X_train, y_train, X_test=None, y_test=None, classes=None, **kwargs):
    """
    Functional API for class prediction error visualization.

    Parameters:
    - estimator: scikit-learn classifier
    - X_train: training features
    - y_train: training labels
    - X_test: test features (optional; if omitted, the training data is scored)
    - y_test: test labels (optional)
    - classes: list of class labels

    Returns:
    ClassPredictionError visualizer instance
    """
```
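
The data behind each stacked bar is a tally of predicted classes per actual class, which matches one row of the confusion matrix. This hypothetical sketch also reports predicted-minus-actual counts per class, a quick way to spot a class the model over-predicts; neither function is part of yellowbrick.

```python
from collections import Counter


def prediction_breakdown(y_true, y_pred, labels):
    """For each actual class, list how often each label was predicted (one stacked bar each)."""
    tallies = {label: Counter() for label in labels}
    for actual, predicted in zip(y_true, y_pred):
        tallies[actual][predicted] += 1
    return {label: [tallies[label][p] for p in labels] for label in labels}


def prediction_bias(y_true, y_pred, labels):
    """Predicted count minus actual count per class; positive means over-predicted."""
    actual, predicted = Counter(y_true), Counter(y_pred)
    return {label: predicted[label] - actual[label] for label in labels}


labels = ["low", "mid", "high"]
y_true = ["low", "low", "mid", "mid", "high", "high"]
y_pred = ["low", "mid", "mid", "mid", "mid", "high"]

bars = prediction_breakdown(y_true, y_pred, labels)
bias = prediction_bias(y_true, y_pred, labels)
print(bars)  # {'low': [1, 1, 0], 'mid': [0, 2, 0], 'high': [0, 1, 1]}
print(bias)  # {'low': -1, 'mid': 2, 'high': -1}
```

Here "mid" absorbs predictions from both neighboring classes, which would show up as "mid"-colored segments inside the "low" and "high" bars.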

### Precision-Recall Curves

Precision-Recall curves for evaluating binary and multi-class classifiers. They are particularly informative on imbalanced datasets: an ROC curve's false positive rate is divided by the (large) number of negatives, so it can look optimistic even when many predicted positives are wrong, whereas precision exposes those false positives directly.

```python { .api }
class PrecisionRecallCurve(ClassificationScoreVisualizer):
    """
    Precision-Recall curve visualizer.

    Parameters:
    - estimator: scikit-learn classifier
    - classes: list of class labels for display
    - binary: bool, whether to force binary classification mode
    - micro: bool, whether to plot the micro-averaged PR curve
    - per_class: bool, whether to plot per-class PR curves
    - iso_f1_curves: bool, whether to draw iso-F1 curves (lines of constant F1 score)
    - fill_area: bool, whether to shade the area under the curve
    - ap_score: bool, whether to annotate the average precision score
    """
    def __init__(self, estimator, classes=None, binary=False, micro=True, per_class=True,
                 iso_f1_curves=False, fill_area=True, ap_score=True, **kwargs): ...
    def fit(self, X, y, **kwargs): ...
    def score(self, X, y, **kwargs): ...
    def show(self, **kwargs): ...

# Alias for compatibility
PRCurve = PrecisionRecallCurve

def precision_recall_curve(estimator, X_train, y_train, X_test=None, y_test=None, classes=None, **kwargs):
    """
    Functional API for precision-recall curve visualization.

    Parameters:
    - estimator: scikit-learn classifier
    - X_train: training features
    - y_train: training labels
    - X_test: test features (optional; if omitted, the training data is scored)
    - y_test: test labels (optional)
    - classes: list of class labels

    Returns:
    PrecisionRecallCurve visualizer instance
    """
```
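
The quantities on the plot can be sketched without any libraries. Below, `pr_points` sweeps the threshold, `average_precision` uses the step-wise definition AP = sum over n of (R_n - R_{n-1}) * P_n (the definition documented for scikit-learn's `average_precision_score`), and `iso_f1_precision` solves F1 = 2PR / (P + R) for P, which is the line an iso-F1 curve traces. All three are hypothetical helpers, and labels are assumed to be 0/1.

```python
def pr_points(y_true, scores):
    """Precision and recall at each distinct score threshold, highest threshold first."""
    n_pos = sum(y_true)
    precision, recall = [], []
    for t in sorted(set(scores), reverse=True):
        flagged = [y for y, s in zip(y_true, scores) if s >= t]
        tp = sum(flagged)  # positives among the instances flagged at this threshold
        precision.append(tp / len(flagged))
        recall.append(tp / n_pos)
    return precision, recall


def average_precision(precision, recall):
    """Step-wise AP: each precision is weighted by the recall gained at that threshold."""
    ap, prev_recall = 0.0, 0.0
    for p, r in zip(precision, recall):
        ap += (r - prev_recall) * p
        prev_recall = r
    return ap


def iso_f1_precision(f1, recall):
    """Precision needed at a given recall to reach a fixed F1 (valid when 2 * recall > f1)."""
    return f1 * recall / (2 * recall - f1)


y_true = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
precision, recall = pr_points(y_true, scores)
print(precision, recall)
print(average_precision(precision, recall))
```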

### Discrimination Threshold

Plots precision, recall, F1-score, and queue rate (the fraction of instances classified as positive, e.g. the share that would be sent for manual review) against the decision threshold of a binary classifier, helping select a threshold that meets specific business requirements. Unlike the other visualizers in this module, `fit()` itself repeats a shuffled train/test split `n_trials` times and draws the median of each metric with a band showing its variability, so there is no separate `score()` step.

```python { .api }
class DiscriminationThreshold(ModelVisualizer):
    """
    Discrimination threshold visualizer for binary classification.

    Parameters:
    - estimator: scikit-learn binary classifier
    - n_trials: int, number of shuffled train/test splits used to estimate
      each metric curve (default: 50)
    - random_state: int, random state for reproducibility
    """
    def __init__(self, estimator, n_trials=50, random_state=None, **kwargs): ...
    def fit(self, X, y, **kwargs): ...
    def show(self, **kwargs): ...

def discrimination_threshold(estimator, X, y, **kwargs):
    """
    Functional API for discrimination threshold visualization.

    Parameters:
    - estimator: scikit-learn binary classifier
    - X: features
    - y: binary target labels

    Returns:
    DiscriminationThreshold visualizer instance
    """
```
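
The metrics plotted here can be sketched for a single split as follows. The helper name, the toy probabilities, and the candidate thresholds are illustrative assumptions; the visualizer additionally repeats the computation over `n_trials` shuffled splits and plots the median, which this sketch omits. Picking the threshold with the highest F1 at the end shows one simple selection rule.

```python
def threshold_metrics(y_true, probs, thresholds):
    """Return one dict of metrics per candidate threshold (binary 0/1 labels)."""
    rows = []
    for t in thresholds:
        # Classify as positive when the predicted probability reaches the threshold
        predicted = [1 if p >= t else 0 for p in probs]
        pairs = list(zip(y_true, predicted))
        tp = sum(1 for y, yhat in pairs if y == 1 and yhat == 1)
        fp = sum(1 for y, yhat in pairs if y == 0 and yhat == 1)
        fn = sum(1 for y, yhat in pairs if y == 1 and yhat == 0)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        rows.append({
            "threshold": t,
            "precision": precision,
            "recall": recall,
            "f1": f1,
            # Queue rate: share of all instances flagged as positive
            "queue_rate": sum(predicted) / len(predicted),
        })
    return rows


y_true = [0, 0, 0, 1, 0, 1, 1, 1]
probs = [0.10, 0.20, 0.30, 0.35, 0.55, 0.60, 0.70, 0.90]
rows = threshold_metrics(y_true, probs, [0.25, 0.50, 0.65])
best = max(rows, key=lambda row: row["f1"])
for row in rows:
    print(row)
print("best threshold by F1:", best["threshold"])
```

Raising the threshold trades recall for precision and shrinks the queue; where to stop depends on the relative cost of false positives and false negatives.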

### Class Balance

Visualization of class distribution in the dataset, helping identify class imbalance issues that may affect model performance.

```python { .api }
class ClassBalance(Visualizer):
    """
    Class balance visualizer for examining target class distributions.

    Parameters:
    - labels: list of class labels for display
    """
    def __init__(self, labels=None, **kwargs): ...
    def fit(self, y, **kwargs): ...
    def show(self, **kwargs): ...

def class_balance(y, labels=None, **kwargs):
    """
    Functional API for class balance visualization.

    Parameters:
    - y: target labels
    - labels: list of class labels for display

    Returns:
    ClassBalance visualizer instance
    """
```
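
Under the hood, a class balance chart is a frequency count of the target. A minimal sketch with hypothetical helper names, which also reports the ratio between the largest and smallest class as a quick imbalance indicator:

```python
from collections import Counter


def class_support(y):
    """Count instances per class, most frequent first."""
    return Counter(y).most_common()


def imbalance_ratio(y):
    """Largest class count divided by the smallest; 1.0 means perfectly balanced."""
    counts = Counter(y).values()
    return max(counts) / min(counts)


y = ["ham"] * 90 + ["spam"] * 10
print(class_support(y))    # [('ham', 90), ('spam', 10)]
print(imbalance_ratio(y))  # 9.0
```

With a 9:1 split like this, a model that always predicts "ham" already scores 90% accuracy, which is why the per-class views above are worth checking alongside overall accuracy.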

## Base Classes

```python { .api }
class ClassificationScoreVisualizer(ScoreVisualizer):
    """
    Base class for classification scoring visualizers.
    Provides common functionality for classification model evaluation.
    """
    def __init__(self, estimator, **kwargs): ...
    def fit(self, X, y, **kwargs): ...
    def score(self, X, y, **kwargs): ...
```
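
The shared workflow these base classes provide (fit the wrapped estimator on training data, score it on held-out data, then render) can be illustrated with a deliberately tiny toy. Neither `MajorityClassifier` nor `AccuracyVisualizer` exists in yellowbrick; both are hypothetical, and printing stands in for matplotlib drawing.

```python
from collections import Counter


class MajorityClassifier:
    """Hypothetical estimator that always predicts the most frequent training label."""

    def fit(self, X, y):
        self.label_ = Counter(y).most_common(1)[0][0]
        return self

    def predict(self, X):
        return [self.label_ for _ in X]


class AccuracyVisualizer:
    """Hypothetical toy mirroring the fit -> score -> show workflow."""

    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y):
        # Train the wrapped estimator; return self so calls can be chained
        self.estimator.fit(X, y)
        return self

    def score(self, X, y):
        # Evaluate on (ideally held-out) data and keep the result for drawing
        y_pred = self.estimator.predict(X)
        self.score_ = sum(p == t for p, t in zip(y_pred, y)) / len(y)
        return self.score_

    def show(self):
        # A real visualizer would finalize and display a matplotlib figure here
        print(f"accuracy: {self.score_:.2f}")


viz = AccuracyVisualizer(MajorityClassifier()).fit([[0], [1], [2], [3]], ["a", "a", "a", "b"])
viz.score([[4], [5], [6], [7]], ["a", "b", "a", "a"])
viz.show()  # prints "accuracy: 0.75"
```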

## Usage Patterns

### Basic Classification Evaluation

```python
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from yellowbrick.classifier import ROCAUC, ConfusionMatrix, ClassificationReport

# Prepare data and model
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
model = RandomForestClassifier(random_state=42)

# ROC/AUC Analysis
roc_viz = ROCAUC(model)
roc_viz.fit(X_train, y_train)
roc_viz.score(X_test, y_test)
roc_viz.show()

# Confusion Matrix (percentages instead of raw counts)
cm_viz = ConfusionMatrix(model, percent=True)
cm_viz.fit(X_train, y_train)
cm_viz.score(X_test, y_test)
cm_viz.show()

# Classification Report
cr_viz = ClassificationReport(model)
cr_viz.fit(X_train, y_train)
cr_viz.score(X_test, y_test)
cr_viz.show()
```

### Multi-class Classification Analysis

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from yellowbrick.classifier import ROCAUC, PrecisionRecallCurve

# Load multi-class dataset and hold out a stratified test split
iris = load_iris()
X, y = iris.data, iris.target
class_names = list(iris.target_names)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=42
)

# Multi-class ROC analysis
model = RandomForestClassifier(random_state=42)
roc_viz = ROCAUC(model, classes=class_names)
roc_viz.fit(X_train, y_train)
roc_viz.score(X_test, y_test)
roc_viz.show()

# Multi-class Precision-Recall
pr_viz = PrecisionRecallCurve(model, classes=class_names, per_class=True, micro=True)
pr_viz.fit(X_train, y_train)
pr_viz.score(X_test, y_test)
pr_viz.show()
```

### Threshold Optimization

```python
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from yellowbrick.classifier import DiscriminationThreshold

# Binary classification threshold analysis
X, y = load_breast_cancer(return_X_y=True)
model = LogisticRegression(max_iter=5000)

# fit() performs n_trials shuffled train/test splits internally,
# so the full dataset is passed and there is no separate score() call
threshold_viz = DiscriminationThreshold(model, n_trials=50, random_state=42)
threshold_viz.fit(X, y)
threshold_viz.show()
```