0
# Visualization
1
2
Built-in plotting functions for model interpretation, feature importance analysis, training progress monitoring, and tree structure visualization. LightGBM's visualization capabilities support both matplotlib and graphviz backends for comprehensive model analysis and presentation.
3
4
## Capabilities
5
6
### Feature Importance Plotting
7
8
Visualize feature importance scores to understand which features contribute most to model predictions.
9
10
```python { .api }
11
def plot_importance(booster, ax=None, height=0.2, xlim=None, ylim=None,
12
title='Feature importance', xlabel='Feature importance',
13
ylabel='Features', importance_type='auto', max_num_features=None,
14
ignore_zero=True, figsize=None, dpi=None, grid=True,
15
precision=3, **kwargs):
16
"""
17
Plot model's feature importance scores.
18
19
Parameters:
20
- booster: Booster or LGBMModel - Trained model to analyze
21
- ax: matplotlib.axes.Axes or None - Matplotlib axes object to plot on
22
- height: float - Bar chart height (spacing between bars)
23
- xlim: tuple or None - X-axis limits (min, max)
24
- ylim: tuple or None - Y-axis limits (min, max)
25
- title: str or None - Plot title
26
- xlabel: str or None - X-axis label
27
- ylabel: str or None - Y-axis label
28
- importance_type: str - Type of importance ('auto', 'split', 'gain')
29
- max_num_features: int or None - Maximum number of features to display
30
- ignore_zero: bool - Whether to ignore features with zero importance
31
- figsize: tuple or None - Figure size (width, height) in inches
32
- dpi: int or None - Figure resolution in dots per inch
33
- grid: bool - Whether to show grid lines
34
- precision: int - Number of decimal places for importance values
35
- **kwargs: Additional matplotlib bar plot parameters
36
37
Returns:
38
- matplotlib.axes.Axes: The matplotlib axes object with the plot
39
"""
40
```
41
42
### Training Metrics Plotting
43
44
Plot training and validation metrics over iterations to monitor model performance and detect overfitting.
45
46
```python { .api }
47
def plot_metric(eval_result, metric=None, ax=None, xlim=None, ylim=None,
48
title='Metric during training', xlabel='Iterations',
49
ylabel='auto', figsize=None, dpi=None, grid=True, **kwargs):
50
"""
51
Plot one metric's curves (one line per evaluated dataset) from training history.
52
53
Parameters:
54
- eval_result: dict - Evaluation results from training (from record_evaluation callback)
55
- metric: str or None - Name of the single metric to plot (if None, one of the recorded metrics is picked; a warning is logged when several are available)
56
- ax: matplotlib.axes.Axes or None - Matplotlib axes object to plot on
57
- xlim: tuple or None - X-axis limits (min, max)
58
- ylim: tuple or None - Y-axis limits (min, max)
59
- title: str or None - Plot title
60
- xlabel: str or None - X-axis label
61
- ylabel: str or 'auto' - Y-axis label ('auto' uses metric name)
62
- figsize: tuple or None - Figure size (width, height) in inches
63
- dpi: int or None - Figure resolution in dots per inch
64
- grid: bool - Whether to show grid lines
65
- **kwargs: Additional matplotlib plot parameters
66
67
Returns:
68
- matplotlib.axes.Axes: The matplotlib axes object with the plot
69
"""
70
```
71
72
### Tree Structure Visualization
73
74
Visualize individual decision trees to understand model decision-making process.
75
76
```python { .api }
77
def plot_tree(booster, ax=None, tree_index=0, figsize=None, dpi=None,
78
show_info=None, precision=3, orientation='horizontal',
79
**kwargs):
80
"""
81
Plot specified tree structure.
82
83
Parameters:
84
- booster: Booster or LGBMModel - Trained model containing trees
85
- ax: matplotlib.axes.Axes or None - Matplotlib axes object to plot on
86
- tree_index: int - Index of tree to visualize
87
- figsize: tuple or None - Figure size (width, height) in inches
88
- dpi: int or None - Figure resolution in dots per inch
89
- show_info: list or None - Information to show in nodes (['split_gain', 'leaf_count', etc.])
90
- precision: int - Number of decimal places for node values
91
- orientation: str - Tree layout ('horizontal' or 'vertical')
92
- **kwargs: Additional parameters passed to create_tree_digraph (and on to the graphviz Digraph constructor)
93
94
Returns:
95
- matplotlib.axes.Axes: The matplotlib axes object with the tree plot
96
"""
97
98
def create_tree_digraph(booster, tree_index=0, show_info=None, precision=3,
99
orientation='horizontal', **kwargs):
100
"""
101
Create graphviz digraph representation of specified tree.
102
103
Parameters:
104
- booster: Booster or LGBMModel - Trained model containing trees
105
- tree_index: int - Index of tree to visualize
106
- show_info: list or None - Information to show in nodes
107
- precision: int - Number of decimal places for node values
108
- orientation: str - Tree layout direction
109
- **kwargs: Additional graphviz parameters
110
111
Returns:
112
- graphviz.Digraph: Graphviz digraph object for the tree
113
"""
114
```
115
116
### Split Value Analysis
117
118
Analyze the distribution of split values for specific features to understand feature usage patterns.
119
120
```python { .api }
121
def plot_split_value_histogram(booster, feature, ax=None, bins=None,
122
color='auto', title='auto', xlabel='auto',
123
ylabel='Count', figsize=None, dpi=None,
124
grid=True, **kwargs):
125
"""
126
Plot histogram of split values for specified feature.
127
128
Parameters:
129
- booster: Booster or LGBMModel - Trained model to analyze
130
- feature: int or str - Feature index or name to analyze
131
- ax: matplotlib.axes.Axes or None - Matplotlib axes object to plot on
132
- bins: int or None - Number of histogram bins (auto-determined if None)
133
- color: str or 'auto' - Histogram bar color
134
- title: str or 'auto' - Plot title ('auto' generates descriptive title)
135
- xlabel: str or 'auto' - X-axis label ('auto' uses feature name)
136
- ylabel: str - Y-axis label
137
- figsize: tuple or None - Figure size (width, height) in inches
138
- dpi: int or None - Figure resolution in dots per inch
139
- grid: bool - Whether to show grid lines
140
- **kwargs: Additional matplotlib histogram parameters
141
142
Returns:
143
- matplotlib.axes.Axes: The matplotlib axes object with the histogram
144
"""
145
```
146
147
## Usage Examples
148
149
### Feature Importance Visualization
150
151
```python
152
import lightgbm as lgb
153
import matplotlib.pyplot as plt
154
from sklearn.datasets import load_breast_cancer
155
from sklearn.model_selection import train_test_split
156
157
# Load and prepare data
158
X, y = load_breast_cancer(return_X_y=True)
159
feature_names = load_breast_cancer().feature_names
160
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
161
162
# Train model
163
model = lgb.LGBMClassifier(n_estimators=100, random_state=42)
164
model.fit(X_train, y_train, feature_name=list(feature_names))
165
166
# Plot feature importance
167
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
168
169
# Plot split-based importance
170
lgb.plot_importance(
171
model,
172
importance_type='split',
173
ax=ax1,
174
max_num_features=15,
175
title='Feature Importance (Split-based)',
176
xlabel='Number of splits',
177
height=0.4
178
)
179
180
# Plot gain-based importance
181
lgb.plot_importance(
182
model,
183
importance_type='gain',
184
ax=ax2,
185
max_num_features=15,
186
title='Feature Importance (Gain-based)',
187
xlabel='Total gain',
188
height=0.4
189
)
190
191
plt.tight_layout()
192
plt.show()
193
```
194
195
### Training Progress Monitoring
196
197
```python
198
import lightgbm as lgb
199
import matplotlib.pyplot as plt
200
from sklearn.datasets import load_diabetes
201
from sklearn.model_selection import train_test_split
202
203
# Load data
204
X, y = load_diabetes(return_X_y=True)
205
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
206
207
# Prepare datasets
208
train_data = lgb.Dataset(X_train, label=y_train)
209
test_data = lgb.Dataset(X_test, label=y_test, reference=train_data)
210
211
# Set up evaluation result recording
212
eval_result = {}
213
214
# Train model with evaluation tracking
215
model = lgb.train(
216
{
217
'objective': 'regression',
218
'metric': ['rmse', 'mae'],
219
'boosting_type': 'gbdt',
220
'num_leaves': 31,
221
'learning_rate': 0.05,
222
'verbose': -1
223
},
224
train_data,
225
num_boost_round=200,
226
valid_sets=[train_data, test_data],
227
valid_names=['train', 'test'],
228
callbacks=[
229
lgb.record_evaluation(eval_result),
230
lgb.early_stopping(20)
231
]
232
)
233
234
# Plot training curves
235
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
236
237
# Plot RMSE
238
lgb.plot_metric(
239
eval_result,
240
metric='rmse',
241
ax=ax1,
242
title='RMSE during Training',
243
ylabel='RMSE'
244
)
245
246
# Plot MAE
247
lgb.plot_metric(
248
eval_result,
249
metric='mae',
250
ax=ax2,
251
title='MAE during Training',
252
ylabel='MAE'
253
)
254
255
plt.tight_layout()
256
plt.show()
257
258
# Print best scores
259
print(f"Best iteration: {model.best_iteration}")
260
print(f"Best RMSE: {eval_result['test']['rmse'][model.best_iteration-1]:.4f}")
261
print(f"Best MAE: {eval_result['test']['mae'][model.best_iteration-1]:.4f}")
262
```
263
264
### Tree Structure Visualization
265
266
```python
267
import lightgbm as lgb
268
import matplotlib.pyplot as plt
269
from sklearn.datasets import load_iris
270
271
# Load simple dataset for clear tree visualization
272
X, y = load_iris(return_X_y=True)
273
feature_names = load_iris().feature_names
274
275
# Train small model for interpretable trees
276
model = lgb.LGBMClassifier(
277
n_estimators=3,
278
max_depth=3,
279
num_leaves=7,
280
random_state=42
281
)
282
model.fit(X, y, feature_name=list(feature_names))
283
284
# Visualize first few trees
285
fig, axes = plt.subplots(1, 3, figsize=(20, 6))
286
287
for i in range(3):
288
lgb.plot_tree(
289
model,
290
tree_index=i,
291
ax=axes[i],
292
figsize=(6, 6),
293
show_info=['split_gain', 'leaf_value'],
294
precision=2
295
)
296
axes[i].set_title(f'Tree {i}')
297
298
plt.tight_layout()
299
plt.show()
300
301
# Alternative: Create graphviz digraph for higher quality
302
try:
303
import graphviz
304
305
# Create digraph for first tree
306
graph = lgb.create_tree_digraph(
307
model,
308
tree_index=0,
309
show_info=['split_gain', 'leaf_value', 'leaf_count'],
310
precision=2
311
)
312
313
# Render to file
314
graph.render('tree_0', format='png', cleanup=True)
315
print("Tree digraph saved as tree_0.png")
316
317
except ImportError:
318
print("Graphviz not available. Install with: pip install graphviz")
319
```
320
321
### Split Value Analysis
322
323
```python
324
import lightgbm as lgb
325
import matplotlib.pyplot as plt
326
import numpy as np
327
from sklearn.datasets import make_regression
328
329
# Generate data with known relationships
330
X, y = make_regression(n_samples=10000, n_features=10, noise=0.1, random_state=42)
331
feature_names = [f'feature_{i}' for i in range(X.shape[1])]
332
333
# Train model
334
model = lgb.LGBMRegressor(
335
n_estimators=100,
336
max_depth=6,
337
random_state=42
338
)
339
model.fit(X, y, feature_name=feature_names)
340
341
# Analyze split values for top features
342
top_features = np.argsort(model.feature_importances_)[-4:] # Top 4 features
343
344
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
345
axes = axes.ravel()
346
347
for i, feature_idx in enumerate(top_features):
348
lgb.plot_split_value_histogram(
349
model,
350
feature=feature_idx,
351
ax=axes[i],
352
bins=30,
353
color='skyblue',
354
alpha=0.7
355
)
356
357
# Add feature importance to title
358
importance = model.feature_importances_[feature_idx]
359
axes[i].set_title(f'{feature_names[feature_idx]} (Importance: {importance:.0f})')
360
361
plt.tight_layout()
362
plt.show()
363
364
# Print split statistics
for feature_idx in top_features:
    # get_split_value_histogram returns (counts, bin_edges), like numpy.histogram
    counts, bin_edges = model.booster_.get_split_value_histogram(feature_idx)
    busiest = np.argmax(counts)
    print(f"\n{feature_names[feature_idx]}:")
    print(f"  Number of splits: {int(counts.sum())}")
    print(f"  Split range: [{bin_edges[0]:.3f}, {bin_edges[-1]:.3f}]")
    print(f"  Most frequent split bin: [{bin_edges[busiest]:.3f}, {bin_edges[busiest + 1]:.3f}]")
371
```
372
373
### Comprehensive Model Analysis Dashboard
374
375
```python
376
import lightgbm as lgb
377
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split

# Load data (load_boston was removed in scikit-learn 1.2, so use a bundled regression dataset)
X, y = load_diabetes(return_X_y=True)
feature_names = load_diabetes().feature_names
384
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
385
386
# Train model with evaluation tracking
387
train_data = lgb.Dataset(X_train, label=y_train, feature_name=list(feature_names))
388
test_data = lgb.Dataset(X_test, label=y_test, reference=train_data)
389
390
eval_result = {}
391
model = lgb.train(
392
{
393
'objective': 'regression',
394
'metric': ['rmse', 'mae'],
395
'num_leaves': 31,
396
'learning_rate': 0.05,
397
'verbose': -1
398
},
399
train_data,
400
num_boost_round=150,
401
valid_sets=[train_data, test_data],
402
valid_names=['train', 'test'],
403
callbacks=[lgb.record_evaluation(eval_result)]
404
)
405
406
# Create comprehensive dashboard
407
fig = plt.figure(figsize=(20, 12))
408
409
# 1. Feature importance
410
ax1 = plt.subplot(2, 3, 1)
411
lgb.plot_importance(model, ax=ax1, max_num_features=10, importance_type='gain')
412
ax1.set_title('Feature Importance (Gain)')
413
414
# 2. Training curves
415
ax2 = plt.subplot(2, 3, 2)
416
lgb.plot_metric(eval_result, metric='rmse', ax=ax2)
417
ax2.set_title('RMSE During Training')
418
419
# 3. Tree structure (first tree)
420
ax3 = plt.subplot(2, 3, 3)
421
lgb.plot_tree(model, tree_index=0, ax=ax3, show_info=['split_gain'])
422
ax3.set_title('First Tree Structure')
423
424
# 4. Split histogram for most important feature
425
top_feature = np.argsort(model.feature_importance())[-1]
426
ax4 = plt.subplot(2, 3, 4)
427
lgb.plot_split_value_histogram(model, feature=top_feature, ax=ax4)
428
ax4.set_title(f'Split Values: {feature_names[top_feature]}')
429
430
# 5. Model predictions vs actual
431
ax5 = plt.subplot(2, 3, 5)
432
predictions = model.predict(X_test)
433
ax5.scatter(y_test, predictions, alpha=0.6)
434
ax5.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'r--', lw=2)
435
ax5.set_xlabel('Actual Values')
436
ax5.set_ylabel('Predicted Values')
437
ax5.set_title('Predictions vs Actual')
438
439
# 6. Residuals plot
440
ax6 = plt.subplot(2, 3, 6)
441
residuals = predictions - y_test
442
ax6.scatter(predictions, residuals, alpha=0.6)
443
ax6.axhline(y=0, color='r', linestyle='--')
444
ax6.set_xlabel('Predicted Values')
445
ax6.set_ylabel('Residuals')
446
ax6.set_title('Residual Plot')
447
448
plt.tight_layout()
449
plt.show()
450
451
# Print model summary
452
print(f"Model Performance:")
453
print(f"Best iteration: {model.best_iteration}")
454
print(f"Test RMSE: {eval_result['test']['rmse'][-1]:.4f}")
455
print(f"Test MAE: {eval_result['test']['mae'][-1]:.4f}")
456
print(f"Number of trees: {model.num_trees()}")
457
print(f"Number of features: {model.num_feature()}")
458
```
459
460
## Customization Options
461
462
### Matplotlib Styling
463
464
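`plot_importance` and `plot_split_value_histogram` forward extra keyword arguments to the underlying matplotlib bar call (`ax.barh` / `ax.bar`), so only parameters that are valid for bar artists can be used for customization: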
465
466
```python
467
# Custom styling example
468
lgb.plot_importance(
    model,
    figsize=(10, 8),
    color='darkblue',
    alpha=0.8,
    edgecolor='black',
    linewidth=1.5,
    grid=True,
    title='Custom Styled Feature Importance',
    xlabel='Importance Score'
)
480
481
# Apply matplotlib style
482
plt.style.use('seaborn-v0_8') # Or any other style
483
lgb.plot_metric(eval_result, metric='rmse')
484
```
485
486
### Saving Plots
487
488
```python
489
# Save plot to file
490
ax = lgb.plot_importance(model, figsize=(10, 6))
491
ax.figure.savefig('feature_importance.png', dpi=300, bbox_inches='tight')
492
493
# Save multiple plots
494
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
495
lgb.plot_importance(model, ax=axes[0,0])
496
lgb.plot_metric(eval_result, ax=axes[0,1])
497
# ... add more plots
498
plt.savefig('model_analysis.pdf', bbox_inches='tight')
499
```