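# Visualization Tools

Specialized plotting functions for analyzing machine learning models, including decision regions, learning curves, and confusion matrices. Every function draws with matplotlib, so the resulting figures and axes can be refined with ordinary matplotlib calls (titles, axis labels, `plt.show()`). Several functions also accept keyword-argument dictionaries, such as `scatter_kwargs`, that are passed through to the underlying matplotlib calls.

## Capabilities
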
### Decision Region Visualization

Visualize the decision boundaries and regions of a classifier in a two-dimensional feature space. When `X` has more than two features, the non-plotted features are held at fixed filler values.

```python { .api }
def plot_decision_regions(X, y, clf, feature_index=None, filler_feature_values=None,
                          filler_feature_ranges=None, ax=None, X_highlight=None,
                          res=0.02, legend=1, hide_spines=True, markers='s^oxv<>',
                          colors='red,blue,limegreen,gray,cyan', scatter_kwargs=None,
                          contourf_kwargs=None, scatter_highlight_kwargs=None):
    """
    Plot decision regions for 2D feature spaces.

    Parameters:
    - X: array-like, feature matrix (shape: [n_samples, n_features])
    - y: array-like, integer class labels (shape: [n_samples])
    - clf: fitted sklearn-compatible classifier with a predict method
    - feature_index: list or tuple, indices of the two features to plot (default: [0, 1])
    - filler_feature_values: dict, {feature index: value} for each non-plotted feature
      (required when X has more than two features)
    - filler_feature_ranges: dict, {feature index: range}; only samples within
      value ± range of each filler value are plotted
    - ax: matplotlib axis, axis to plot on (default: current axis)
    - X_highlight: array-like, samples to highlight (e.g., a test set)
    - res: float, grid resolution (spacing of the mesh on which clf is evaluated)
    - legend: int, legend location code passed to matplotlib (0 = no legend,
      1 = upper right, 2 = upper left)
    - hide_spines: bool, hide plot spines
    - markers: str, marker symbols for classes
    - colors: str, comma-separated color palette for classes
    - scatter_kwargs: dict, additional scatter plot parameters
    - contourf_kwargs: dict, additional contourf parameters for the filled regions
    - scatter_highlight_kwargs: dict, scatter parameters for highlighted samples

    Returns:
    - ax: matplotlib axis object
    """
```

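To make the role of `res` concrete, the sketch below builds the kind of evaluation grid a decision-region plot relies on: a mesh spanning the data with spacing `res`, classified point by point and reshaped for `contourf`. It is an illustration under stated assumptions, not mlxtend's implementation. The one-unit padding and the nearest-centroid classifier (`nearest_centroid`, a hypothetical stand-in for `clf.predict`) are choices made only for this example.

```python
import numpy as np


def decision_grid(X, predict, res=0.02, pad=1.0):
    """Evaluate `predict` on a regular grid covering the first two columns of X."""
    # Pad the bounding box (assumed 1 unit) so regions extend past the outermost samples
    x_min, x_max = X[:, 0].min() - pad, X[:, 0].max() + pad
    y_min, y_max = X[:, 1].min() - pad, X[:, 1].max() + pad
    # `res` is the spacing between neighbouring grid points: smaller = finer but slower
    xx, yy = np.meshgrid(np.arange(x_min, x_max, res),
                         np.arange(y_min, y_max, res))
    # Classify every grid point, then restore the 2D grid shape expected by contourf
    Z = predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
    return xx, yy, Z


# Toy data: two well-separated classes
X = np.array([[0.0, 0.0], [0.2, 0.1], [2.0, 2.0], [2.1, 1.9]])
y = np.array([0, 0, 1, 1])
centroids = np.array([X[y == k].mean(axis=0) for k in np.unique(y)])


def nearest_centroid(points):
    """Hypothetical stand-in for clf.predict: label of the closest class centroid."""
    sq_dist = ((points[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return sq_dist.argmin(axis=1)


xx, yy, Z = decision_grid(X, nearest_centroid, res=0.1)
```

Passing the result to `plt.contourf(xx, yy, Z, alpha=0.3)` and overlaying `plt.scatter(X[:, 0], X[:, 1], c=y)` reproduces the basic picture. Halving `res` roughly quadruples the number of points that must be classified, which is why very small values can be slow for expensive classifiers.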
### Learning Curve Visualization

Plot learning curves that show how training and test performance change as the classifier is trained on increasingly large portions of the training set.

```python { .api }
def plot_learning_curves(X_train, y_train, X_test, y_test, clf, train_marker='o',
                         test_marker='^', marker_size=7, alpha=0.75,
                         scoring='misclassification error', suppress_plot=False,
                         print_model=True, style='fivethirtyeight', legend_loc='best'):
    """
    Plot learning curves showing training and test performance.

    Parameters:
    - X_train: array-like, training features
    - y_train: array-like, training labels
    - X_test: array-like, test features
    - y_test: array-like, test labels
    - clf: sklearn-compatible classifier (refit on each training subset)
    - train_marker: str, marker style for training scores
    - test_marker: str, marker style for test scores
    - marker_size: int, size of markers
    - alpha: float, marker transparency
    - scoring: str, scoring metric (e.g., 'misclassification error', 'accuracy', 'roc_auc')
    - suppress_plot: bool, compute the scores without drawing the plot
    - print_model: bool, show the model parameters in the plot title
    - style: str, matplotlib style sheet
    - legend_loc: str, legend location

    Returns:
    - train_scores: list, training scores
    - test_scores: list, test scores
    """
```

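The curves come from repeatedly refitting the classifier on growing subsets of the training data. The following sketch shows that loop with misclassification error (1 - accuracy) as the metric. It is a simplified illustration, not mlxtend's code: the ten evenly spaced subset sizes, the toy Gaussian data, and the hypothetical `centroid_fit_predict` classifier are assumptions chosen to keep the example dependency-free.

```python
import numpy as np


def centroid_fit_predict(X_fit, y_fit, X_new):
    """Hypothetical classifier: fit class centroids on (X_fit, y_fit), predict X_new."""
    classes = np.unique(y_fit)
    centroids = np.array([X_fit[y_fit == c].mean(axis=0) for c in classes])
    sq_dist = ((X_new[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return classes[sq_dist.argmin(axis=1)]


def learning_curve_errors(X_train, y_train, X_test, y_test, fit_predict, n_steps=10):
    """Misclassification error on growing training subsets and on a fixed test set."""
    # Evenly spaced subset sizes: 1/n_steps, 2/n_steps, ..., all of the training set
    sizes = np.linspace(0, len(X_train), n_steps + 1).astype(int)[1:]
    train_err, test_err = [], []
    for n in sizes:
        X_sub, y_sub = X_train[:n], y_train[:n]
        # Training error: evaluated on the same subset the model was fit on
        train_err.append(np.mean(fit_predict(X_sub, y_sub, X_sub) != y_sub))
        # Test error: always evaluated on the full held-out test set
        test_err.append(np.mean(fit_predict(X_sub, y_sub, X_test) != y_test))
    return sizes, train_err, test_err


# Toy data: two Gaussian blobs, shuffled, then split 140 / 60
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0.0, 1.0, (100, 2)), rng.normal(3.0, 1.0, (100, 2))])
y = np.repeat([0, 1], 100)
order = rng.permutation(len(y))
X, y = X[order], y[order]

sizes, train_err, test_err = learning_curve_errors(
    X[:140], y[:140], X[140:], y[140:], centroid_fit_predict)
```

Plotting `sizes` against `train_err` and `test_err` gives the two curves. A persistent gap between them suggests overfitting (high variance), while two curves that converge at a high error suggest underfitting (high bias).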
### Confusion Matrix Visualization

Draw a confusion matrix as an annotated grid, optionally with normalized values, class names, and a colorbar.

```python { .api }
def plot_confusion_matrix(conf_mat, hide_spines=False, hide_ticks=False,
                          figsize=None, cmap=None, colorbar=False, show_absolute=True,
                          show_normed=False, normed_type='all', class_names=None):
    """
    Plot confusion matrix with customization options.

    Parameters:
    - conf_mat: array-like, confusion matrix (shape: [n_classes, n_classes])
    - hide_spines: bool, hide plot spines
    - hide_ticks: bool, hide axis ticks
    - figsize: tuple, figure size (width, height)
    - cmap: str, colormap name
    - colorbar: bool, show colorbar
    - show_absolute: bool, show absolute counts
    - show_normed: bool, show normalized values
    - normed_type: str, normalization type ('all', 'pred', 'true')
    - class_names: list, class label names

    Returns:
    - fig: matplotlib figure object
    - ax: matplotlib axis object
    """
```

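For reference, the three normalization options can be read like the `normalize='true'`, `'pred'`, and `'all'` options of scikit-learn's `confusion_matrix`, where rows are true labels and columns are predicted labels. The sketch below assumes that convention for `normed_type`. The helper `normalize_confusion` is hypothetical and only illustrates the arithmetic.

```python
import numpy as np


def normalize_confusion(conf_mat, normed_type="all"):
    """Normalize a confusion matrix (rows = true labels, columns = predicted labels)."""
    cm = np.asarray(conf_mat, dtype=float)
    if normed_type == "all":
        # Fraction of all samples in each cell
        return cm / cm.sum()
    if normed_type == "true":
        # Each row sums to 1; the diagonal holds per-class recall
        return cm / cm.sum(axis=1, keepdims=True)
    if normed_type == "pred":
        # Each column sums to 1; the diagonal holds per-class precision
        return cm / cm.sum(axis=0, keepdims=True)
    raise ValueError(f"unknown normed_type: {normed_type!r}")


cm = np.array([[8, 2],
               [1, 9]])
by_true = normalize_confusion(cm, "true")   # [[0.8, 0.2], [0.1, 0.9]]
by_all = normalize_confusion(cm, "all")     # [[0.4, 0.1], [0.05, 0.45]]
```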
### Feature Selection Visualization

Plot the cross-validated performance of each feature subset found by a sequential feature selection run, with an error band around the average score.

```python { .api }
def plot_sequential_feature_selection(metric_dict, kind='std_dev', color='blue',
                                      barchart=False, figsize=None):
    """
    Plot sequential feature selection results.

    Parameters:
    - metric_dict: dict, output of SequentialFeatureSelector.get_metric_dict()
    - kind: str, error band type ('std_dev', 'std_err', 'ci')
    - color: str, plot color
    - barchart: bool, use bar chart instead of line plot
    - figsize: tuple, figure size

    Returns:
    - fig: matplotlib figure object
    - ax: matplotlib axis object
    """
```

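To clarify what the `kind` options measure, the sketch below turns per-fold cross-validation scores into an average score and an error band. It uses a hand-written dictionary shaped like the output of `get_metric_dict()` (feature-subset size mapped to a dict holding `'cv_scores'`). The formulas are the textbook sample standard deviation and standard error of the mean; mlxtend's exact conventions (for example, its degrees-of-freedom correction, and the confidence interval used for `'ci'`, which is omitted here) may differ.

```python
import numpy as np

# Hand-written stand-in for SequentialFeatureSelector.get_metric_dict() output
metric_dict = {
    1: {"cv_scores": np.array([0.70, 0.72, 0.68, 0.71, 0.69])},
    2: {"cv_scores": np.array([0.80, 0.83, 0.79, 0.82, 0.81])},
}


def error_band(cv_scores, kind="std_dev"):
    """Half-width of the error band drawn around the mean CV score."""
    scores = np.asarray(cv_scores, dtype=float)
    sd = scores.std(ddof=1)                  # sample standard deviation across folds
    if kind == "std_dev":
        return sd
    if kind == "std_err":
        return sd / np.sqrt(len(scores))     # standard error of the mean
    raise ValueError(f"kind not covered by this sketch: {kind!r}")


k_values = sorted(metric_dict)
means = np.array([metric_dict[k]["cv_scores"].mean() for k in k_values])
bands = np.array([error_band(metric_dict[k]["cv_scores"], "std_dev") for k in k_values])
lower, upper = means - bands, means + bands  # the shaded region around the curve
```

`'std_err'` shrinks the band by a factor of the square root of the number of folds relative to `'std_dev'`, so it describes uncertainty in the mean score rather than fold-to-fold spread.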
### Linear Regression Visualization

Plot a simple (one-feature) linear regression fit over a scatter plot of the data, annotated with a correlation coefficient.

```python { .api }
def plot_linear_regression(X, y, model=None, corr_func='pearsonr', scattercolor='blue',
                           fit_style='k--', legend=True, xlim='auto'):
    """
    Plot linear regression fit with correlation coefficient.

    Parameters:
    - X: array-like, feature values (1D)
    - y: array-like, target values
    - model: sklearn-compatible regressor with fit and predict methods; it is
      fitted on X and y inside the function
    - corr_func: str or callable, 'pearsonr' (scipy.stats.pearsonr) or a function
      such as scipy.stats.spearmanr that returns (coefficient, p-value)
    - scattercolor: str, scatter plot color
    - fit_style: str, regression line style
    - legend: bool, show legend with correlation
    - xlim: str or tuple, x-axis limits

    Returns:
    - regression_fit: tuple, (intercept, slope, corr_coeff)
    """
```

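The numbers such a plot reports can be reproduced with NumPy alone: a least-squares line (intercept and slope) and Pearson's correlation coefficient. This worked example uses `np.polyfit` and `np.corrcoef` rather than the plotting function, so it shows the arithmetic, not mlxtend's internals; the data are made up for illustration.

```python
import numpy as np

x = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
y = np.array([1.1, 2.9, 5.2, 6.8, 9.0])

# Degree-1 polyfit returns coefficients highest power first: [slope, intercept]
slope, intercept = np.polyfit(x, y, deg=1)

# Pearson r is the off-diagonal entry of the 2x2 correlation matrix
r = np.corrcoef(x, y)[0, 1]

# Points on the fitted line, e.g. for drawing it with fit_style='k--'
y_hat = intercept + slope * x
```

By hand: the means are 2 and 5, the sum of cross-deviations is 19.7 and the sum of squared x-deviations is 10, so the slope is 1.97 and the intercept is 5 - 1.97 * 2 = 1.06. For a Spearman coefficient, `scipy.stats.spearmanr(x, y)` computes the rank-based equivalent.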
### Specialized Plot Types

General-purpose plotting helpers for exploratory data analysis: categorical scatter plots, heatmaps, stacked bar plots, enrichment plots, checkerboard plots, empirical cumulative distribution functions (ECDFs), scatter plot matrices, PCA correlation graphs, scatter plots with marginal histograms, and a utility for removing plot borders.

```python { .api }
def category_scatter(x, y, label_col, selection=None, alpha=1.0, markers='o',
                     colors=None, figsize=(7, 5)):
    """
    Create scatter plot with categorical coloring.

    Parameters:
    - x: str or array-like, x-axis data
    - y: str or array-like, y-axis data
    - label_col: str or array-like, categorical labels
    - selection: list, subset of categories to plot
    - alpha: float, point transparency
    - markers: str, marker symbols
    - colors: list, color palette
    - figsize: tuple, figure size

    Returns:
    - fig: matplotlib figure object
    """

def heatmap(ary, xlabels=None, ylabels=None, fmt='%.1f', cmap='Blues',
            cbar=True, cbar_kws=None, figsize=None):
    """
    Create heatmap visualization.

    Parameters:
    - ary: array-like, 2D data matrix
    - xlabels: list, x-axis labels
    - ylabels: list, y-axis labels
    - fmt: str, number format string
    - cmap: str, colormap name
    - cbar: bool, show colorbar
    - cbar_kws: dict, colorbar keyword arguments
    - figsize: tuple, figure size

    Returns:
    - fig: matplotlib figure object
    - ax: matplotlib axis object
    """

def stacked_barplot(df, bar_names=None, figsize=(8, 5), n_legend_cols=1,
                    legend_loc='best'):
    """
    Create stacked bar plot from DataFrame.

    Parameters:
    - df: DataFrame, data with categories as columns
    - bar_names: list, names for bars (uses index if None)
    - figsize: tuple, figure size
    - n_legend_cols: int, number of legend columns
    - legend_loc: str, legend location

    Returns:
    - fig: matplotlib figure object
    - ax: matplotlib axis object
    """

def enrichment_plot(df, colors='bgrkcy', markers=' ', linestyles='-',
                    alpha=1.0, lw=2, where='post', grid=True, count_label='Count',
                    xlim=None, ylim=None, invert_axes=False, legend_loc='best'):
    """
    Create enrichment plot for feature analysis.

    Parameters:
    - df: DataFrame, enrichment data
    - colors: str, color sequence
    - markers: str, marker sequence
    - linestyles: str, line style sequence
    - alpha: float, line transparency
    - lw: float, line width
    - where: str, step placement for the step plot ('pre', 'post', 'mid')
    - grid: bool, show grid
    - count_label: str, y-axis label
    - xlim: tuple, x-axis limits
    - ylim: tuple, y-axis limits
    - invert_axes: bool, swap x and y axes
    - legend_loc: str, legend location

    Returns:
    - ax: matplotlib axis object
    """

def checkerboard_plot(ary, fmt='%.1f', figsize=None, cbar=False, cmap=None,
                      labels_x=None, labels_y=None, fontsize_data=12):
    """
    Create checkerboard-style matrix plot.

    Parameters:
    - ary: array-like, 2D data matrix
    - fmt: str, number format string
    - figsize: tuple, figure size
    - cbar: bool, show colorbar
    - cmap: str, colormap name
    - labels_x: list, x-axis labels
    - labels_y: list, y-axis labels
    - fontsize_data: int, font size for data values

    Returns:
    - fig: matplotlib figure object
    - ax: matplotlib axis object
    """

def ecdf(x, y_label='ECDF', x_label=None, ax=None, percentile=None, **kwargs):
    """
    Plot empirical cumulative distribution function.

    Parameters:
    - x: array-like, data values
    - y_label: str, y-axis label
    - x_label: str, x-axis label
    - ax: matplotlib axis, axis to plot on
    - percentile: float, percentile line to highlight
    - kwargs: additional plot parameters

    Returns:
    - ax: matplotlib axis object
    - (percentile_val, percentile_prob): tuple if percentile specified
    """

def scatterplotmatrix(X, names=None, figsize=(8, 8), alpha=1.0, **kwargs):
    """
    Create scatter plot matrix for multiple variables.

    Parameters:
    - X: array-like, feature matrix
    - names: list, variable names
    - figsize: tuple, figure size
    - alpha: float, point transparency
    - kwargs: additional scatter plot parameters

    Returns:
    - fig: matplotlib figure object
    - axes: array of axis objects
    """

def plot_pca_correlation_graph(X, variables_names, dimensions=(1, 2),
                               figsize=(10, 8), X_pca=None, explained_variance=None):
    """
    Plot PCA correlation graph showing variable relationships.

    Parameters:
    - X: array-like, original feature matrix
    - variables_names: list, variable names
    - dimensions: tuple, PCA dimensions to plot
    - figsize: tuple, figure size
    - X_pca: array-like, pre-computed PCA transform
    - explained_variance: array-like, explained variance ratios

    Returns:
    - fig: matplotlib figure object
    - ax: matplotlib axis object
    """

def scatter_hist(x, y, hist_bins=20, hist_range=None, alpha=0.5,
                 scatter_kwargs=None, hist_kwargs=None, figsize=(5, 5)):
    """
    Create scatter plot with marginal histograms.

    Parameters:
    - x: array-like, x-axis data
    - y: array-like, y-axis data
    - hist_bins: int, number of histogram bins
    - hist_range: tuple, histogram range
    - alpha: float, histogram transparency
    - scatter_kwargs: dict, scatter plot parameters
    - hist_kwargs: dict, histogram parameters
    - figsize: tuple, figure size

    Returns:
    - fig: matplotlib figure object
    - axes: dict of axis objects {'scatter', 'hist_x', 'hist_y'}
    """

def remove_borders(axes=None):
    """
    Remove borders and spines from matplotlib plots.

    Parameters:
    - axes: matplotlib axis or list of axes, axes to modify
    """
```

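The `ecdf` plot, and its optional `percentile` marker, rest on a simple computation: sort the data and assign the i-th smallest value the cumulative proportion i/n. The sketch below shows that computation with NumPy. Interpreting the percentile marker as the smallest value whose ECDF reaches the requested proportion is an assumption of this example, and `ecdf_points` and `percentile_threshold` are hypothetical helpers, not part of the library.

```python
import numpy as np


def ecdf_points(x):
    """Return sorted values and their empirical cumulative proportions i/n."""
    xs = np.sort(np.asarray(x, dtype=float))
    ps = np.arange(1, len(xs) + 1) / len(xs)
    return xs, ps


def percentile_threshold(x, percentile):
    """Smallest value whose ECDF is >= percentile (a fraction in (0, 1])."""
    xs, ps = ecdf_points(x)
    idx = np.searchsorted(ps, percentile, side="left")  # first index with ps >= percentile
    threshold = xs[idx]
    # Number of samples at or below the threshold (counts ties correctly)
    count = int(np.count_nonzero(xs <= threshold))
    return threshold, count


data = [3, 1, 4, 1, 5, 9, 2, 6]
xs, ps = ecdf_points(data)
threshold, count = percentile_threshold(data, 0.5)   # threshold 3.0, count 4
```

Drawing `plt.step(xs, ps, where='post')` gives the staircase ECDF, and a vertical line at `threshold` marks the chosen percentile.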
## Usage Examples

### Decision Regions Example

```python
from mlxtend.plotting import plot_decision_regions
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
import matplotlib.pyplot as plt

# Create 2D dataset
X, y = make_classification(n_samples=300, n_features=2, n_redundant=0,
                           n_informative=2, random_state=42, n_clusters_per_class=1)

# Train classifier
clf = RandomForestClassifier(random_state=42)
clf.fit(X, y)

# Plot decision regions
plot_decision_regions(X, y, clf=clf, legend=2)
plt.title('Random Forest Decision Regions')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.show()
```

### Learning Curves Example

```python
from mlxtend.plotting import plot_learning_curves
from sklearn.svm import SVC
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

# Create dataset
X, y = make_classification(n_samples=1000, n_features=20, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Create classifier
clf = SVC(random_state=42)

# Plot learning curves
train_scores, test_scores = plot_learning_curves(
    X_train, y_train, X_test, y_test, clf,
    scoring='accuracy', style='ggplot'
)
plt.title('SVM Learning Curves')
plt.show()
```

### Sequential Feature Selection Example

```python
from mlxtend.feature_selection import SequentialFeatureSelector
from mlxtend.plotting import plot_sequential_feature_selection
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
import matplotlib.pyplot as plt

# Create dataset
X, y = make_classification(n_samples=500, n_features=15, random_state=42)

# Perform sequential feature selection
clf = RandomForestClassifier(random_state=42)
sfs = SequentialFeatureSelector(clf, k_features=8, forward=True,
                                scoring='accuracy', cv=5)
sfs.fit(X, y)

# Plot results
plot_sequential_feature_selection(sfs.get_metric_dict(), kind='std_dev')
plt.title('Sequential Feature Selection Results')
plt.show()
```

### Confusion Matrix Example

```python
from mlxtend.plotting import plot_confusion_matrix
from sklearn.metrics import confusion_matrix
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

# Create dataset and train classifier.
# n_informative must be raised from its default of 2: make_classification requires
# n_classes * n_clusters_per_class (3 * 2) <= 2 ** n_informative.
X, y = make_classification(n_samples=1000, n_features=20, n_informative=5,
                           n_classes=3, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

clf = RandomForestClassifier(random_state=42)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)

# Create and plot confusion matrix
cm = confusion_matrix(y_test, y_pred)
plot_confusion_matrix(cm, class_names=['Class 0', 'Class 1', 'Class 2'],
                      show_normed=True, colorbar=True)
plt.title('Confusion Matrix')
plt.show()
```