XGBoost is an optimized distributed gradient boosting library designed to be highly efficient, flexible, and portable.
XGBoost provides comprehensive visualization tools for model interpretation, including feature importance plots, tree structure visualization, and GraphViz export capabilities. These tools help understand model behavior and decision-making processes.
Visualize feature importance scores using matplotlib with customizable styling and display options.
def plot_importance(
    booster,
    ax=None,
    height=0.2,
    xlim=None,
    ylim=None,
    title='Feature importance',
    xlabel='F score',
    ylabel='Features',
    fmap='',
    importance_type='weight',
    max_num_features=None,
    grid=True,
    show_values=True,
    values_format='{v}',
    **kwargs
):
    """Plot feature importance as a horizontal bar chart.

    Parameters
    ----------
    booster : Booster or dict
        Booster object or feature importance dict.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on; a new figure is created when None.
    height : float
        Bar height for the horizontal bar chart.
    xlim : tuple of (min, max), optional
        X-axis limits.
    ylim : tuple of (min, max), optional
        Y-axis limits.
    title : str
        Plot title.
    xlabel : str
        X-axis label.
    ylabel : str
        Y-axis label.
    fmap : str
        Feature map file path used to resolve feature names.
    importance_type : str
        Importance type: 'weight', 'gain', 'cover', 'total_gain',
        or 'total_cover'.
    max_num_features : int, optional
        Maximum number of features to display (all when None).
    show_values : bool
        Whether to show importance values on the bars.
    grid : bool
        Whether to show grid lines.
    values_format : str
        Format string for importance values; '{v}' is replaced by the value.
    **kwargs
        Additional matplotlib arguments.

    Returns
    -------
    matplotlib.axes.Axes
        The axes object containing the plot.
    """

# Plot individual decision trees from the ensemble with customizable layout and styling.
def plot_tree(
    booster,
    fmap='',
    num_trees=0,
    rankdir=None,
    ax=None,
    **kwargs
):
    """Plot the specified tree of the ensemble.

    Parameters
    ----------
    booster : Booster
        Booster object to plot.
    fmap : str
        Feature map file path used to resolve feature names.
    num_trees : int
        Tree index to plot (0-based).
    rankdir : str, optional
        Direction of the tree layout ('UT', 'LR', 'TB', 'BT').
    ax : matplotlib.axes.Axes, optional
        Axes to plot on; a new figure is created when None.
    **kwargs
        Additional graphviz or matplotlib arguments.

    Returns
    -------
    matplotlib.axes.Axes
        The axes object containing the plot.

    Notes
    -----
    Requires the graphviz package for tree rendering.
    """

# Export model trees to GraphViz DOT format for external visualization and processing.
def to_graphviz(
booster,
fmap='',
num_trees=0,
rankdir=None,
yes_color=None,
no_color=None,
condition_node_params=None,
leaf_node_params=None,
**kwargs
):
"""
Convert specified tree to graphviz format.
Parameters:
- booster: Booster object to convert
- fmap: Feature map file path for feature names
- num_trees: Tree index to convert (0-based)
- rankdir: Direction of graph layout ('UT', 'LR', 'TB', 'BT')
- yes_color: Edge color for 'yes' branches (hex color)
- no_color: Edge color for 'no' branches (hex color)
- condition_node_params: Dictionary of graphviz node parameters for condition nodes
- leaf_node_params: Dictionary of graphviz node parameters for leaf nodes
- **kwargs: Additional graphviz parameters
Returns:
graphviz.Source: GraphViz source object that can be rendered
Note:
Requires graphviz package. Returned object can be saved or rendered:
- graph.render('tree', format='png') saves to file
- graph.view() opens in viewer
"""import xgboost as xgb
import matplotlib.pyplot as plt
# NOTE: load_boston was removed in scikit-learn 1.2; use the California
# housing dataset instead for a drop-in regression example.
from sklearn.datasets import fetch_california_housing

# Train a small regression model to demonstrate importance plotting.
X, y = fetch_california_housing(return_X_y=True)
dtrain = xgb.DMatrix(X, label=y)
params = {'objective': 'reg:squarederror', 'max_depth': 3}
model = xgb.train(params, dtrain, num_boost_round=100)

# Plot feature importance by gain for the top 10 features.
fig, ax = plt.subplots(figsize=(10, 8))
xgb.plot_importance(
    model,
    ax=ax,
    importance_type='gain',
    max_num_features=10,
    title='Top 10 Feature Importance (Gain)',
    show_values=True
)
plt.tight_layout()
plt.show()

# Custom styling: thicker bars, fixed x-range, no grid, custom labels.
xgb.plot_importance(
    model,
    height=0.5,
    xlim=(0, 0.1),
    grid=False,
    color='green',
    title='Feature Importance',
    xlabel='Importance Score',
    ylabel='Feature Names'
)

import xgboost as xgb
import matplotlib.pyplot as plt

# Train a shallow model with only a few trees for visualization.
dtrain = xgb.DMatrix(X, label=y)
model = xgb.train(params, dtrain, num_boost_round=5)

# Plot the first tree with a top-to-bottom layout.
fig, ax = plt.subplots(figsize=(15, 10))
xgb.plot_tree(
    model,
    num_trees=0,
    ax=ax,
    rankdir='TB'  # Top to bottom layout
)
plt.show()

# Plot the first four trees in a 2x2 grid.
fig, axes = plt.subplots(2, 2, figsize=(20, 15))
for i, ax in enumerate(axes.flat):
    if i < 4:  # Plot first 4 trees
        xgb.plot_tree(model, num_trees=i, ax=ax)
        ax.set_title(f'Tree {i}')
plt.tight_layout()
plt.show()

import xgboost as xgb
# Train a small model and export one of its trees to GraphViz.
model = xgb.train(params, dtrain, num_boost_round=3)

graph = xgb.to_graphviz(
    model,
    num_trees=0,
    rankdir='LR',         # Left to right layout
    yes_color='#0000FF',  # Blue for yes branches
    no_color='#FF0000',   # Red for no branches
    condition_node_params={'shape': 'box', 'style': 'filled', 'fillcolor': 'lightblue'},
    leaf_node_params={'shape': 'ellipse', 'style': 'filled', 'fillcolor': 'lightgreen'}
)

# Save the rendered tree to tree_visualization.png.
graph.render('tree_visualization', format='png', cleanup=True)
# Open the rendering in the default viewer.
graph.view()
# Inspect the underlying DOT source code.
dot_source = graph.source
print(dot_source)

# Create feature map file
# Write a tab-separated feature map: <id>\t<name>\t<type>.
feature_names = ['feature_0', 'feature_1', 'feature_2', 'price', 'age']
with open('feature_map.txt', 'w') as f:
    for i, name in enumerate(feature_names):
        f.write(f'{i}\t{name}\tq\n')  # q for quantitative

# Use the feature map so plots display the custom feature names.
xgb.plot_importance(
    model,
    fmap='feature_map.txt',
    title='Feature Importance with Custom Names'
)
xgb.plot_tree(
    model,
    fmap='feature_map.txt',
    num_trees=0,
    rankdir='TB'
)
graph = xgb.to_graphviz(
    model,
    fmap='feature_map.txt',
    num_trees=0
)

# Get importance scores directly
# Retrieve raw importance scores as dicts keyed by feature name.
importance_weight = model.get_score(importance_type='weight')
importance_gain = model.get_score(importance_type='gain')
importance_cover = model.get_score(importance_type='cover')
print("Feature importance by weight:", importance_weight)
print("Feature importance by gain:", importance_gain)
print("Feature importance by cover:", importance_cover)

# Compare the three importance types side by side.
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
for ax, imp_type in zip(axes, ['weight', 'gain', 'cover']):
    xgb.plot_importance(
        model,
        ax=ax,
        importance_type=imp_type,
        title=f'Feature Importance ({imp_type.title()})',
        max_num_features=10
    )
plt.tight_layout()
plt.show()

# The visualization functions require additional packages:
pip install graphviz

Feature map files use a tab-separated format:
<feature_id>\t<feature_name>\t<feature_type>

Where feature_type can be:
- q: quantitative/numerical
- c: categorical
- i: indicator/binary

Install with Tessl CLI:
npx tessl i tessl/pypi-xgboost