# Tree-Specific Explanation (Greybox)

Specialized explanation methods for tree-based models. These explainers use the internal structure of decision trees, random forests, and gradient boosting models to produce explanations more efficiently and more accurately than model-agnostic approaches.

## Capabilities

### SHAP Tree Explainer

A SHAP implementation for tree-based models that computes exact Shapley values by exploiting the tree structure, rather than approximating them by sampling. It is much faster than kernel SHAP on tree models and keeps the same theoretical guarantees.

```python { .api }
class ShapTree:
    def __init__(
        self,
        model,
        data,
        feature_names=None,
        feature_types=None,
        model_output='raw',
        **kwargs
    ):
        """
        SHAP tree explainer for tree-based models.

        Parameters:
            model: Tree-based model (scikit-learn trees, XGBoost, LightGBM, CatBoost)
            data (array-like): Background data for computing expectations
            feature_names (list, optional): Names for features
            feature_types (list, optional): Types for features
            model_output (str): Type of model output ('raw', 'probability', 'log_loss')
            **kwargs: Additional arguments for TreeExplainer
        """

    def explain_local(self, X, y=None, name=None, interactions=False, **kwargs):
        """
        Generate SHAP explanations for tree model predictions.

        Parameters:
            X (array-like): Instances to explain
            y (array-like, optional): True labels
            name (str, optional): Name for explanation
            interactions (bool): Whether to compute interaction values
            **kwargs: Additional arguments

        Returns:
            Explanation object with SHAP values optimized for trees
        """

    def explain_global(self, name=None, max_display=20):
        """
        Generate global SHAP summary for tree model.

        Parameters:
            name (str, optional): Name for explanation
            max_display (int): Maximum features to display

        Returns:
            Global explanation with feature importance rankings
        """
```
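
To make "exact Shapley values" concrete, the sketch below computes them by brute force for a tiny hand-written tree. It assumes an interventional value function: features outside a coalition are filled in from a small background set and the predictions are averaged. The names `predict`, `BACKGROUND`, `coalition_value`, and `shapley_values` are hypothetical illustrations and are not part of `interpret` or `shap`. Tree-specific algorithms aim to obtain such values without enumerating every coalition.

```python
from itertools import combinations
from math import factorial

def predict(x):
    """Toy depth-2 regression tree over two features (hypothetical)."""
    if x[0] <= 0.5:
        return 1.0
    return 3.0 if x[1] <= 0.5 else 7.0

# Background rows used to marginalize out "absent" features.
BACKGROUND = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (1.0, 1.0)]

def coalition_value(x, present):
    """Mean prediction when only the features in `present` take x's values."""
    total = 0.0
    for row in BACKGROUND:
        mixed = tuple(x[i] if i in present else row[i] for i in range(len(x)))
        total += predict(mixed)
    return total / len(BACKGROUND)

def shapley_values(x):
    """Exact Shapley values by enumerating every coalition (exponential cost)."""
    m = len(x)
    phi = [0.0] * m
    for i in range(m):
        others = [j for j in range(m) if j != i]
        for size in range(m):
            weight = factorial(size) * factorial(m - size - 1) / factorial(m)
            for subset in combinations(others, size):
                s = set(subset)
                phi[i] += weight * (coalition_value(x, s | {i}) - coalition_value(x, s))
    return phi

x = (1.0, 1.0)
phi = shapley_values(x)
base = coalition_value(x, set())

# Additivity: the base value plus all contributions reconstructs the prediction.
assert abs(base + sum(phi) - predict(x)) < 1e-9
print(base, phi)
```

The enumeration visits every subset of the other features, which is why brute force stops being practical beyond a handful of features.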

### Tree Interpreter

Direct interpretation of tree model decisions: each prediction's path is traced from root to leaf, and a feature contribution is computed at every decision node along the way.

```python { .api }
class TreeInterpreter:
    def __init__(
        self,
        model,
        feature_names=None,
        **kwargs
    ):
        """
        Tree interpreter for decision path analysis.

        Parameters:
            model: Tree-based model (scikit-learn trees, random forests)
            feature_names (list, optional): Names for features
            **kwargs: Additional arguments
        """

    def explain_local(self, X, y=None, name=None, **kwargs):
        """
        Explain predictions by analyzing decision paths.

        Parameters:
            X (array-like): Instances to explain
            y (array-like, optional): True labels
            name (str, optional): Name for explanation
            **kwargs: Additional arguments

        Returns:
            Explanation object with decision path contributions
        """

    def explain_global(self, name=None):
        """
        Generate global tree structure analysis.

        Parameters:
            name (str, optional): Name for explanation

        Returns:
            Global explanation with tree structure insights
        """
```
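
The path-tracing idea can be sketched on a hand-built tree. In this sketch each node stores the mean training target of the samples that reach it, and moving from a node to its child credits the change in that mean to the feature tested at the split. The `Node` class and `path_contributions` function are hypothetical illustrations under that assumption, not the `interpret` API.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Node:
    value: float                      # mean target of samples reaching this node
    feature: Optional[int] = None     # index of the split feature (None for a leaf)
    threshold: float = 0.0
    left: Optional["Node"] = None     # branch taken when x[feature] <= threshold
    right: Optional["Node"] = None

def path_contributions(root, x, n_features):
    """Return (bias, contributions); bias + sum(contributions) equals the leaf value."""
    contributions = [0.0] * n_features
    node = root
    while node.feature is not None:
        child = node.left if x[node.feature] <= node.threshold else node.right
        # Credit the change in node value to the feature used for this split.
        contributions[node.feature] += child.value - node.value
        node = child
    return root.value, contributions

tree = Node(
    value=3.0, feature=0, threshold=0.5,
    left=Node(value=1.0),
    right=Node(value=5.0, feature=1, threshold=0.5,
               left=Node(value=3.0), right=Node(value=7.0)),
)

bias, contribs = path_contributions(tree, x=(1.0, 1.0), n_features=2)
print(bias, contribs)
```

Because every step along the path is credited to exactly one feature, the contributions always sum to the difference between the leaf value and the root value.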

## Usage Examples

### SHAP Tree Explainer with Random Forest

```python
from interpret.greybox import ShapTree
from interpret import show
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

# Load data and train model
data = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, test_size=0.2, random_state=42
)

rf = RandomForestClassifier(n_estimators=100, random_state=42)
rf.fit(X_train, y_train)

# Create SHAP tree explainer
shap_tree = ShapTree(
    model=rf,
    data=X_train,
    feature_names=data.feature_names,
    model_output='probability'
)

# Get local explanations (much faster than kernel SHAP)
local_exp = shap_tree.explain_local(X_test[:10], name="SHAP Tree Local")
show(local_exp)

# Get global explanations
global_exp = shap_tree.explain_global(name="SHAP Tree Global")
show(global_exp)
```

### SHAP with Interaction Values

```python
# Compute interaction effects for tree models
# (reuses shap_tree, X_test, and show from the previous example)
interaction_exp = shap_tree.explain_local(
    X_test[:5],
    interactions=True,
    name="SHAP Interactions"
)
show(interaction_exp)
```

### Tree Interpreter Analysis

```python
from interpret.greybox import TreeInterpreter
from sklearn.tree import DecisionTreeClassifier

# Reuses X_train, X_test, y_train, data, and show from the first example

# Train single decision tree for clearer interpretation
tree = DecisionTreeClassifier(max_depth=5, random_state=42)
tree.fit(X_train, y_train)

# Create tree interpreter
tree_interp = TreeInterpreter(
    model=tree,
    feature_names=data.feature_names
)

# Analyze decision paths
path_exp = tree_interp.explain_local(X_test[:5], name="Decision Paths")
show(path_exp)

# Global tree analysis
tree_global = tree_interp.explain_global(name="Tree Structure")
show(tree_global)
```

### Gradient Boosting with SHAP

```python
from sklearn.ensemble import GradientBoostingClassifier

# Reuses ShapTree, show, X_train, X_test, y_train, and data from the first example

# Train gradient boosting model
gbm = GradientBoostingClassifier(n_estimators=100, random_state=42)
gbm.fit(X_train, y_train)

# SHAP tree explainer works with gradient boosting too
shap_gbm = ShapTree(
    model=gbm,
    data=X_train[:200],  # Sample background data
    feature_names=data.feature_names
)

gbm_exp = shap_gbm.explain_local(X_test[:5], name="GBM SHAP")
show(gbm_exp)
```

### XGBoost Integration

```python
import xgboost as xgb

# Reuses ShapTree, show, X_train, X_test, y_train, and data from the first example

# Train XGBoost model
xgb_model = xgb.XGBClassifier(n_estimators=100, random_state=42)
xgb_model.fit(X_train, y_train)

# SHAP tree explainer supports XGBoost natively
shap_xgb = ShapTree(
    model=xgb_model,
    data=X_train[:200],
    feature_names=data.feature_names
)

xgb_exp = shap_xgb.explain_local(X_test[:5], name="XGBoost SHAP")
show(xgb_exp)

# Global analysis
xgb_global = shap_xgb.explain_global(name="XGBoost Global")
show(xgb_global)
```

### Performance Comparison

```python
import time
from interpret.blackbox import ShapKernel

# Compare performance: Tree SHAP vs Kernel SHAP
# (reuses rf, shap_tree, X_train, X_test, data, and show from the first example)
instances = X_test[:10]

# Tree SHAP (optimized); perf_counter is a monotonic, high-resolution timer
start_time = time.perf_counter()
tree_exp = shap_tree.explain_local(instances, name="Tree SHAP")
tree_time = time.perf_counter() - start_time

# Kernel SHAP (model-agnostic)
kernel_shap = ShapKernel(
    predict_fn=rf.predict_proba,
    data=X_train[:100],
    feature_names=data.feature_names
)

start_time = time.perf_counter()
kernel_exp = kernel_shap.explain_local(instances, name="Kernel SHAP")
kernel_time = time.perf_counter() - start_time

print(f"Tree SHAP time: {tree_time:.2f}s")
print(f"Kernel SHAP time: {kernel_time:.2f}s")
print(f"Speedup: {kernel_time/tree_time:.1f}x")

# Show both results
show(tree_exp)
show(kernel_exp)
```
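
A single wall-clock measurement can be noisy, so timing comparisons are often repeated and the fastest run reported. The helper below is a minimal sketch using only the standard library; `best_of` is a hypothetical name, and the commented usage assumes the `shap_tree`, `kernel_shap`, and `instances` objects from the comparison example exist.

```python
import time

def best_of(fn, repeats=3):
    """Run fn() `repeats` times and return (fastest_seconds, last_result)."""
    best = float("inf")
    result = None
    for _ in range(repeats):
        start = time.perf_counter()
        result = fn()
        best = min(best, time.perf_counter() - start)
    return best, result

# Stand-in workload so the sketch runs on its own.
elapsed, total = best_of(lambda: sum(range(100_000)))
print(f"fastest run: {elapsed:.6f}s, result: {total}")

# With the objects from the comparison example (assumed to exist):
# tree_time, tree_exp = best_of(lambda: shap_tree.explain_local(instances))
# kernel_time, kernel_exp = best_of(lambda: kernel_shap.explain_local(instances))
```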

## Model Support

### Supported Tree Models

SHAP Tree Explainer supports:
- scikit-learn: `DecisionTreeClassifier`, `DecisionTreeRegressor`, `RandomForestClassifier`, `RandomForestRegressor`, `ExtraTreesClassifier`, `ExtraTreesRegressor`, `GradientBoostingClassifier`, `GradientBoostingRegressor`
- XGBoost: `XGBClassifier`, `XGBRegressor`, `XGBRanker`
- LightGBM: `LGBMClassifier`, `LGBMRegressor`, `LGBMRanker`
- CatBoost: `CatBoostClassifier`, `CatBoostRegressor`

Tree Interpreter supports:
- scikit-learn tree models
- Some ensemble methods, such as random forests (with limitations)
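
When a pipeline accepts arbitrary estimators, it can help to check a model against the list above before choosing an explainer. The sketch below is a hypothetical helper, not part of `interpret`; it encodes only the class names listed in this section and matches on the class name alone, so a subclass with a different name would not be recognized.

```python
SHAP_TREE_SUPPORTED = {
    # scikit-learn
    "DecisionTreeClassifier", "DecisionTreeRegressor",
    "RandomForestClassifier", "RandomForestRegressor",
    "ExtraTreesClassifier", "ExtraTreesRegressor",
    "GradientBoostingClassifier", "GradientBoostingRegressor",
    # XGBoost
    "XGBClassifier", "XGBRegressor", "XGBRanker",
    # LightGBM
    "LGBMClassifier", "LGBMRegressor", "LGBMRanker",
    # CatBoost
    "CatBoostClassifier", "CatBoostRegressor",
}

def is_shap_tree_candidate(model):
    """True if the model's class name appears in the supported list."""
    return type(model).__name__ in SHAP_TREE_SUPPORTED

# Stand-in classes so the sketch runs without any ML library installed.
class RandomForestClassifier:
    pass

class LogisticRegression:
    pass

print(is_shap_tree_candidate(RandomForestClassifier()))
print(is_shap_tree_candidate(LogisticRegression()))
```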

### Integration Tips

```python
import xgboost as xgb
from interpret.greybox import ShapTree

# For XGBoost, set the objective explicitly
xgb_model = xgb.XGBClassifier(objective='binary:logistic')

# `model`, `regression_model`, and `data` below are placeholders for your own
# fitted estimators and background data.

# For multi-class, SHAP returns values for each class
# Use model_output='probability' for probability outputs
shap_multi = ShapTree(model, data, model_output='probability')

# For regression models
shap_reg = ShapTree(regression_model, data, model_output='raw')
```
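
The choice of `model_output` changes the units of an explanation. For a binary classifier whose raw output is a log-odds margin (as with a logistic objective), raw-space contributions add up to the margin, and the probability follows by applying the sigmoid to that sum. The sketch below illustrates the arithmetic with made-up numbers; `base_value` and `contributions` are hypothetical and do not come from a fitted model.

```python
import math

def sigmoid(z):
    """Map a log-odds margin to a probability."""
    return 1.0 / (1.0 + math.exp(-z))

# Hypothetical raw-space explanation for one instance (log-odds units).
base_value = -0.4                    # expected margin over the background data
contributions = [1.2, -0.3, 0.5]     # one value per feature

# Raw-space contributions are additive: they sum to the model's margin.
margin = base_value + sum(contributions)
probability = sigmoid(margin)
print(f"margin={margin:.2f}, probability={probability:.3f}")
```

Because the sigmoid is nonlinear, raw-space contributions cannot be converted to probability-space contributions one feature at a time, which is one reason to request probability output directly when probabilities are what needs explaining.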