or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

classification.md clustering.md data-utilities.md features.md index.md model-selection.md regression.md text.md

classification.md docs/

0

# Classification Analysis

1

2

Comprehensive visualizers for evaluating classification model performance, providing insights into prediction accuracy, class distributions, decision boundaries, and threshold optimization. These tools support both binary and multi-class classification problems.

3

4

## Capabilities

5

6

### ROC/AUC Analysis

7

8

ROC (Receiver Operating Characteristic) curves and AUC (Area Under Curve) analysis for binary and multi-class classification models. Visualizes the trade-off between true positive rate and false positive rate across different classification thresholds.

9

10

```python { .api }

11

class ROCAUC(ClassificationScoreVisualizer):

12

"""

13

ROC/AUC visualizer for classification models.

14

15

Parameters:

16

- estimator: scikit-learn classifier

17

- ax: matplotlib axes object, axes to plot on

18

- micro: bool, whether to plot micro-averaged ROC for multi-class (default: True)

19

- macro: bool, whether to plot macro-averaged ROC for multi-class (default: True)

20

- per_class: bool, whether to plot per-class ROC curves (default: True)

21

- binary: bool, whether to force binary classification mode (default: False)

22

- classes: list of class labels for display

23

- encoder: label encoder for transforming class labels

24

- is_fitted: str or bool, whether estimator is already fitted ("auto", True, False)

25

- force_model: bool, whether to force model usage even if not required

26

"""

27

def __init__(self, estimator, ax=None, micro=True, macro=True, per_class=True, binary=False, classes=None, encoder=None, is_fitted="auto", force_model=False, **kwargs): ...

28

def fit(self, X, y, **kwargs): ...

29

def score(self, X, y, **kwargs): ...

30

def show(self, **kwargs): ...

31

32

def roc_auc(estimator, X_train, y_train, X_test=None, y_test=None, classes=None, **kwargs):

33

"""

34

Functional API for ROC/AUC visualization.

35

36

Parameters:

37

- estimator: scikit-learn classifier

38

- X_train: training features

39

- y_train: training labels

40

- X_test: test features (optional)

41

- y_test: test labels (optional)

42

- classes: list of class labels

43

44

Returns:

45

ROCAUC visualizer instance

46

"""

47

```

48

49

**Usage Example:**

50

51

```python

52

from yellowbrick.classifier import ROCAUC, roc_auc

53

from sklearn.ensemble import RandomForestClassifier

54

from sklearn.model_selection import train_test_split

55

56

# Class-based API

57

model = RandomForestClassifier()

58

visualizer = ROCAUC(model, classes=['Benign', 'Malignant'])

59

visualizer.fit(X_train, y_train)

60

visualizer.score(X_test, y_test)

61

visualizer.show()

62

63

# Functional API

64

roc_auc(model, X_train, y_train, X_test, y_test, classes=['Benign', 'Malignant'])

65

```

66

67

### Confusion Matrix

68

69

Confusion matrix visualization showing prediction accuracy and error patterns across different classes. Displays counts or percentages with customizable color schemes and normalization options.

70

71

```python { .api }

72

class ConfusionMatrix(ClassificationScoreVisualizer):

73

"""

74

Confusion matrix visualizer for classification models.

75

76

Parameters:

77

- estimator: scikit-learn classifier

78

- ax: matplotlib axes object, axes to plot on

79

- sample_weight: array-like of sample weights

80

- percent: bool, whether to display percentages instead of counts (default: False)

81

- classes: list of class labels for display

82

- encoder: label encoder for transforming class labels

83

- cmap: str, matplotlib colormap name (default: "YlOrRd")

84

- fontsize: int, font size for matrix text

85

- is_fitted: str or bool, whether estimator is already fitted ("auto", True, False)

86

- force_model: bool, whether to force model usage even if not required

87

"""

88

def __init__(self, estimator, ax=None, sample_weight=None, percent=False, classes=None, encoder=None, cmap="YlOrRd", fontsize=None, is_fitted="auto", force_model=False, **kwargs): ...

89

def fit(self, X, y, **kwargs): ...

90

def score(self, X, y, **kwargs): ...

91

def show(self, **kwargs): ...

92

93

def confusion_matrix(estimator, X_train, y_train, X_test=None, y_test=None, classes=None, **kwargs):

94

"""

95

Functional API for confusion matrix visualization.

96

97

Parameters:

98

- estimator: scikit-learn classifier

99

- X_train: training features

100

- y_train: training labels

101

- X_test: test features (optional)

102

- y_test: test labels (optional)

103

- classes: list of class labels

104

105

Returns:

106

ConfusionMatrix visualizer instance

107

"""

108

```

109

110

### Classification Report

111

112

Heatmap visualization of classification metrics including precision, recall, F1-score, and support for each class. Provides a comprehensive overview of model performance across all classes.

113

114

```python { .api }

115

class ClassificationReport(ClassificationScoreVisualizer):

116

"""

117

Classification report heatmap visualizer.

118

119

Parameters:

120

- estimator: scikit-learn classifier

121

- classes: list of class labels for display

122

- sample_weight: array-like of sample weights

123

- support: bool, whether to draw support column

124

- cmap: matplotlib colormap for heatmap

125

"""

126

def __init__(self, estimator, classes=None, sample_weight=None, support=True, cmap='RdYlBu_r', **kwargs): ...

127

def fit(self, X, y, **kwargs): ...

128

def score(self, X, y, **kwargs): ...

129

def show(self, **kwargs): ...

130

131

def classification_report(estimator, X_train, y_train, X_test=None, y_test=None, classes=None, **kwargs):

132

"""

133

Functional API for classification report visualization.

134

135

Parameters:

136

- estimator: scikit-learn classifier

137

- X_train: training features

138

- y_train: training labels

139

- X_test: test features (optional)

140

- y_test: test labels (optional)

141

- classes: list of class labels

142

143

Returns:

144

ClassificationReport visualizer instance

145

"""

146

```

147

148

### Class Prediction Error

149

150

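Stacked bar chart with one bar per actual class, segmented by the classes the model predicted for those samples. It shows how predictions are distributed relative to the true class distribution, helping identify systematic prediction biases and class imbalance issues.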

151

152

```python { .api }

153

class ClassPredictionError(ClassificationScoreVisualizer):

154

"""

155

Class prediction error visualizer.

156

157

Parameters:

158

- estimator: scikit-learn classifier

159

- classes: list of class labels for display

160

"""

161

def __init__(self, estimator, classes=None, **kwargs): ...

162

def fit(self, X, y, **kwargs): ...

163

def score(self, X, y, **kwargs): ...

164

def show(self, **kwargs): ...

165

166

def class_prediction_error(estimator, X_train, y_train, X_test=None, y_test=None, classes=None, **kwargs):

167

"""

168

Functional API for class prediction error visualization.

169

170

Parameters:

171

- estimator: scikit-learn classifier

172

- X_train: training features

173

- y_train: training labels

174

- X_test: test features (optional)

175

- y_test: test labels (optional)

176

- classes: list of class labels

177

178

Returns:

179

ClassPredictionError visualizer instance

180

"""

181

```

182

183

### Precision-Recall Curves

184

185

Precision-Recall curves for evaluating binary and multi-class classifiers, particularly useful for imbalanced datasets where ROC curves may be overly optimistic.

186

187

```python { .api }

188

class PrecisionRecallCurve(ClassificationScoreVisualizer):

189

"""

190

Precision-Recall curve visualizer.

191

192

Parameters:

193

- estimator: scikit-learn classifier

194

- classes: list of class labels for display

195

- binary: bool, whether to force binary classification mode

196

- micro: bool, whether to plot micro-averaged PR curve

197

- per_class: bool, whether to plot per-class PR curves

198

- iso_f1_curves: bool, whether to draw iso-F1 curves

199

- fill_area: bool, whether to fill area under curve

200

- ap_score: bool, whether to annotate average precision score

201

"""

202

def __init__(self, estimator, classes=None, binary=False, micro=True, per_class=True, iso_f1_curves=False, fill_area=True, ap_score=True, **kwargs): ...

203

def fit(self, X, y, **kwargs): ...

204

def score(self, X, y, **kwargs): ...

205

def show(self, **kwargs): ...

206

207

# Alias for compatibility

208

PRCurve = PrecisionRecallCurve

209

210

def precision_recall_curve(estimator, X_train, y_train, X_test=None, y_test=None, classes=None, **kwargs):

211

"""

212

Functional API for precision-recall curve visualization.

213

214

Parameters:

215

- estimator: scikit-learn classifier

216

- X_train: training features

217

- y_train: training labels

218

- X_test: test features (optional)

219

- y_test: test labels (optional)

220

- classes: list of class labels

221

222

Returns:

223

PrecisionRecallCurve visualizer instance

224

"""

225

```

226

227

### Discrimination Threshold

228

229

Visualization of precision, recall, F1-score, and queue rate across different classification thresholds, helping optimize threshold selection for specific business requirements.

230

231

```python { .api }

232

class DiscriminationThreshold(ClassificationScoreVisualizer):

233

"""

234

Discrimination threshold visualizer for binary classification.

235

236

Parameters:

237

- estimator: scikit-learn binary classifier

238

- n_trials: int, number of times to shuffle and split the data to smooth out noise in the threshold curves (default: 50)

239

- random_state: int, random state for reproducibility

240

"""

241

def __init__(self, estimator, n_trials=50, random_state=None, **kwargs): ...

242

def fit(self, X, y, **kwargs): ...

243

def score(self, X, y, **kwargs): ...

244

def show(self, **kwargs): ...

245

246

def discrimination_threshold(estimator, X_train, y_train, X_test=None, y_test=None, **kwargs):

247

"""

248

Functional API for discrimination threshold visualization.

249

250

Parameters:

251

- estimator: scikit-learn binary classifier

252

- X_train: training features

253

- y_train: training labels

254

- X_test: test features (optional)

255

- y_test: test labels (optional)

256

257

Returns:

258

DiscriminationThreshold visualizer instance

259

"""

260

```

261

262

### Class Balance

263

264

Visualization of class distribution in the dataset, helping identify class imbalance issues that may affect model performance.

265

266

```python { .api }

267

class ClassBalance(Visualizer):

268

"""

269

Class balance visualizer for examining target class distributions.

270

271

Parameters:

272

- labels: list of class labels for display

273

"""

274

def __init__(self, labels=None, **kwargs): ...

275

def fit(self, y, **kwargs): ...

276

def show(self, **kwargs): ...

277

278

def class_balance(y, labels=None, **kwargs):

279

"""

280

Functional API for class balance visualization.

281

282

Parameters:

283

- y: target labels

284

- labels: list of class labels for display

285

286

Returns:

287

ClassBalance visualizer instance

288

"""

289

```

290

291

## Base Classes

292

293

```python { .api }

294

class ClassificationScoreVisualizer(ScoreVisualizer):

295

"""

296

Base class for classification scoring visualizers.

297

Provides common functionality for classification model evaluation.

298

"""

299

def __init__(self, estimator, **kwargs): ...

300

def fit(self, X, y, **kwargs): ...

301

def score(self, X, y, **kwargs): ...

302

```

303

304

## Usage Patterns

305

306

### Basic Classification Evaluation

307

308

```python

309

from yellowbrick.classifier import ROCAUC, ConfusionMatrix, ClassificationReport

310

from sklearn.ensemble import RandomForestClassifier

311

from sklearn.model_selection import train_test_split

312

313

# Prepare data and model

314

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

315

model = RandomForestClassifier()

316

317

# ROC/AUC Analysis

318

roc_viz = ROCAUC(model)

319

roc_viz.fit(X_train, y_train)

320

roc_viz.score(X_test, y_test)

321

roc_viz.show()

322

323

# Confusion Matrix

324

cm_viz = ConfusionMatrix(model, percent=True)

325

cm_viz.fit(X_train, y_train)

326

cm_viz.score(X_test, y_test)

327

cm_viz.show()

328

329

# Classification Report

330

cr_viz = ClassificationReport(model)

331

cr_viz.fit(X_train, y_train)

332

cr_viz.score(X_test, y_test)

333

cr_viz.show()

334

```

335

336

### Multi-class Classification Analysis

337

338

```python

339

from yellowbrick.classifier import ROCAUC, PrecisionRecallCurve

340

from sklearn.datasets import load_iris

341

from sklearn.ensemble import RandomForestClassifier

342

343

# Load multi-class dataset

344

iris = load_iris()

345

X, y = iris.data, iris.target

346

class_names = iris.target_names

347

348

# Multi-class ROC analysis

349

model = RandomForestClassifier()

350

roc_viz = ROCAUC(model, classes=class_names)

351

roc_viz.fit(X_train, y_train)

352

roc_viz.score(X_test, y_test)

353

roc_viz.show()

354

355

# Multi-class Precision-Recall

356

pr_viz = PrecisionRecallCurve(model, classes=class_names, per_class=True, micro=True)

357

pr_viz.fit(X_train, y_train)

358

pr_viz.score(X_test, y_test)

359

pr_viz.show()

360

```

361

362

### Threshold Optimization

363

364

```python

365

from yellowbrick.classifier import DiscriminationThreshold

366

from sklearn.linear_model import LogisticRegression

367

368

# Binary classification threshold analysis

369

model = LogisticRegression()

370

threshold_viz = DiscriminationThreshold(model)

371

threshold_viz.fit(X_train, y_train)

372

threshold_viz.score(X_test, y_test)

373

threshold_viz.show()

374

```