or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

blackbox.md, data.md, glassbox.md, greybox.md, index.md, performance.md, privacy.md, utils.md, visualization.md

docs/greybox.md

0

# Tree-Specific Explanation (Greybox)

1

2

Specialized explanation methods optimized for tree-based models. These explainers leverage the internal structure of decision trees, random forests, and gradient boosting models to provide more efficient and accurate explanations than model-agnostic approaches.

3

4

## Capabilities

5

6

### SHAP Tree Explainer

7

8

An optimized SHAP implementation for tree-based models that computes exact Shapley values by exploiting the tree structure. On tree models it is much faster than Kernel SHAP while keeping the same theoretical guarantees.

9

10

```python { .api }

11

class ShapTree:
    """SHAP explainer specialised for tree-based models."""

    def __init__(
        self,
        model,
        data,
        feature_names=None,
        feature_types=None,
        model_output='raw',
        **kwargs
    ):
        """
        Build a SHAP explainer for a tree-based model.

        Parameters:
            model: Tree-based model (scikit-learn trees, XGBoost, LightGBM,
                CatBoost).
            data (array-like): Background samples used to compute expected
                values.
            feature_names (list, optional): Display name for each feature.
            feature_types (list, optional): Type for each feature.
            model_output (str): Which model output to explain; one of
                'raw', 'probability' or 'log_loss'.
            **kwargs: Extra arguments intended for TreeExplainer.
        """

    def explain_local(self, X, y=None, name=None, interactions=False, **kwargs):
        """
        Explain individual predictions of the tree model with SHAP values.

        Parameters:
            X (array-like): Instances whose predictions are explained.
            y (array-like, optional): Ground-truth labels for X.
            name (str, optional): Label attached to the explanation.
            interactions (bool): If True, compute interaction values as well.
            **kwargs: Extra arguments.

        Returns:
            Explanation object holding tree-optimized SHAP values.
        """

    def explain_global(self, name=None, max_display=20):
        """
        Summarise SHAP values across the model.

        Parameters:
            name (str, optional): Label attached to the explanation.
            max_display (int): Upper bound on the number of features shown.

        Returns:
            Global explanation ranking features by importance.
        """

59

```

60

61

### Tree Interpreter

62

63

Direct interpretation of tree model decisions by tracing prediction paths and computing feature contributions at each decision node.

64

65

```python { .api }

66

class TreeInterpreter:
    """Explainer that attributes predictions along tree decision paths."""

    def __init__(
        self,
        model,
        feature_names=None,
        **kwargs
    ):
        """
        Build a decision-path interpreter for a tree model.

        Parameters:
            model: Tree-based model (scikit-learn trees, random forests).
            feature_names (list, optional): Display name for each feature.
            **kwargs: Extra arguments.
        """

    def explain_local(self, X, y=None, name=None, **kwargs):
        """
        Explain predictions by tracing each instance's decision path.

        Parameters:
            X (array-like): Instances whose predictions are explained.
            y (array-like, optional): Ground-truth labels for X.
            name (str, optional): Label attached to the explanation.
            **kwargs: Extra arguments.

        Returns:
            Explanation object with per-feature decision-path contributions.
        """

    def explain_global(self, name=None):
        """
        Analyse the structure of the tree model as a whole.

        Parameters:
            name (str, optional): Label attached to the explanation.

        Returns:
            Global explanation describing the tree structure.
        """

106

```

107

108

## Usage Examples

109

110

### SHAP Tree Explainer with Random Forest

111

112

```python

113

from interpret import show
from interpret.greybox import ShapTree
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Load the breast-cancer dataset and hold out 20% of it for explanation
data = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    data.data,
    data.target,
    test_size=0.2,
    random_state=42,
)

# Fit a 100-tree random forest on the training split
rf = RandomForestClassifier(n_estimators=100, random_state=42)
rf.fit(X_train, y_train)

# Tree-specific SHAP explainer that explains predicted probabilities,
# using the whole training split as background data
shap_tree = ShapTree(
    model=rf,
    data=X_train,
    feature_names=data.feature_names,
    model_output='probability',
)

# Local explanations for the first ten held-out rows
# (much faster than kernel SHAP on tree models)
local_exp = shap_tree.explain_local(X_test[:10], name="SHAP Tree Local")
show(local_exp)

# Global summary across the model
global_exp = shap_tree.explain_global(name="SHAP Tree Global")
show(global_exp)

143

```

144

145

### SHAP with Interaction Values

146

147

```python

148

# Ask the tree explainer for interaction effects on the first five test rows
# (reuses shap_tree, X_test and show from the random-forest example above)
interaction_exp = shap_tree.explain_local(
    X_test[:5],
    interactions=True,
    name="SHAP Interactions",
)
show(interaction_exp)

155

```

156

157

### Tree Interpreter Analysis

158

159

```python

160

from interpret.greybox import TreeInterpreter
from sklearn.tree import DecisionTreeClassifier

# A single depth-limited tree keeps the decision paths short and readable
# (reuses X_train, y_train, X_test, data and show from the first example)
tree = DecisionTreeClassifier(max_depth=5, random_state=42)
tree.fit(X_train, y_train)

# Interpreter over the fitted tree
tree_interp = TreeInterpreter(model=tree, feature_names=data.feature_names)

# Decision-path contributions for five test rows
path_exp = tree_interp.explain_local(X_test[:5], name="Decision Paths")
show(path_exp)

# Structure of the tree as a whole
tree_global = tree_interp.explain_global(name="Tree Structure")
show(tree_global)

180

```

181

182

### Gradient Boosting with SHAP

183

184

```python

185

from sklearn.ensemble import GradientBoostingClassifier

# Fit a 100-stage gradient-boosting classifier on the same training split
gbm = GradientBoostingClassifier(n_estimators=100, random_state=42)
gbm.fit(X_train, y_train)

# ShapTree also accepts gradient-boosted ensembles; here the first 200
# training rows serve as the background sample
shap_gbm = ShapTree(
    model=gbm,
    data=X_train[:200],
    feature_names=data.feature_names,
)

# Local explanations for five test rows
gbm_exp = shap_gbm.explain_local(X_test[:5], name="GBM SHAP")
show(gbm_exp)

200

```

201

202

### XGBoost Integration

203

204

```python

205

import xgboost as xgb

# Fit a 100-tree XGBoost classifier on the same training split
xgb_model = xgb.XGBClassifier(n_estimators=100, random_state=42)
xgb_model.fit(X_train, y_train)

# XGBoost models are passed to ShapTree directly; the first 200 training
# rows serve as the background sample
shap_xgb = ShapTree(
    model=xgb_model,
    data=X_train[:200],
    feature_names=data.feature_names,
)

# Local explanations for five test rows
xgb_exp = shap_xgb.explain_local(X_test[:5], name="XGBoost SHAP")
show(xgb_exp)

# Global summary across the model
xgb_global = shap_xgb.explain_global(name="XGBoost Global")
show(xgb_global)

224

```

225

226

### Performance Comparison

227

228

```python

229

import time

from interpret.blackbox import ShapKernel

# Compare the cost of Tree SHAP and model-agnostic Kernel SHAP on the same
# ten test rows. time.perf_counter() is a monotonic, high-resolution clock
# meant for measuring intervals. time.time() can jump when the system clock
# is adjusted, and its coarse resolution on some platforms can make the fast
# Tree SHAP run measure as 0s (which would break the speedup division).
instances = X_test[:10]

# Tree SHAP (optimized for trees)
start_time = time.perf_counter()
tree_exp = shap_tree.explain_local(instances, name="Tree SHAP")
tree_time = time.perf_counter() - start_time

# Kernel SHAP (model-agnostic; only queries rf.predict_proba).
# NOTE: this uses a 100-row background sample, while shap_tree was built on
# all of X_train, so the timings are not a strictly like-for-like comparison.
kernel_shap = ShapKernel(
    predict_fn=rf.predict_proba,
    data=X_train[:100],
    feature_names=data.feature_names,
)

start_time = time.perf_counter()
kernel_exp = kernel_shap.explain_local(instances, name="Kernel SHAP")
kernel_time = time.perf_counter() - start_time

print(f"Tree SHAP time: {tree_time:.2f}s")
print(f"Kernel SHAP time: {kernel_time:.2f}s")
# Guard against a zero-length measurement before dividing
if tree_time > 0:
    print(f"Speedup: {kernel_time/tree_time:.1f}x")
else:
    print("Speedup: n/a (Tree SHAP time below timer resolution)")

# Show both results
show(tree_exp)
show(kernel_exp)

258

```

259

260

## Model Support

261

262

### Supported Tree Models

263

264

SHAP Tree Explainer supports:

265

- scikit-learn: `DecisionTreeClassifier`, `DecisionTreeRegressor`, `RandomForestClassifier`, `RandomForestRegressor`, `ExtraTreesClassifier`, `ExtraTreesRegressor`, `GradientBoostingClassifier`, `GradientBoostingRegressor`

266

- XGBoost: `XGBClassifier`, `XGBRegressor`, `XGBRanker`

267

- LightGBM: `LGBMClassifier`, `LGBMRegressor`, `LGBMRanker`

268

- CatBoost: `CatBoostClassifier`, `CatBoostRegressor`

269

270

Tree Interpreter supports:

271

- scikit-learn tree models

272

- Some ensemble methods (with limitations)

273

274

### Integration Tips

275

276

```python

277

from sklearn.datasets import load_diabetes, load_iris
from sklearn.ensemble import RandomForestRegressor

# For XGBoost binary classification, set the objective explicitly
# (binary:logistic for 0/1 labels), then fit before explaining
xgb_model = xgb.XGBClassifier(objective='binary:logistic')
xgb_model.fit(X_train, y_train)

# For multi-class, SHAP returns values for each class.
# Use model_output='probability' to explain probability outputs.
iris = load_iris()
multi_model = RandomForestClassifier(n_estimators=100, random_state=42)
multi_model.fit(iris.data, iris.target)
shap_multi = ShapTree(
    multi_model,
    iris.data,
    feature_names=iris.feature_names,
    model_output='probability',
)

# For regression models, explain the raw prediction
diabetes = load_diabetes()
regression_model = RandomForestRegressor(n_estimators=100, random_state=42)
regression_model.fit(diabetes.data, diabetes.target)
shap_reg = ShapTree(
    regression_model,
    diabetes.data[:200],
    feature_names=diabetes.feature_names,
    model_output='raw',
)

286

```