Fit interpretable models and explain blackbox machine learning with comprehensive interpretability tools.
## Greybox (tree-specific) explainers
Specialized explanation methods optimized for tree-based models. These explainers leverage the internal structure of decision trees, random forests, and gradient boosting models to provide more efficient and accurate explanations than model-agnostic approaches.
Optimized SHAP implementation for tree-based models that provides exact Shapley values by exploiting the tree structure. Much faster than kernel SHAP for tree models while maintaining theoretical guarantees.
class ShapTree:
    """SHAP tree explainer for tree-based models (interpret greybox API).

    Computes exact Shapley values by exploiting the tree structure, which
    is much faster than kernel SHAP for tree models while keeping the
    theoretical guarantees.

    NOTE(review): this source shows documentation stubs only — the method
    bodies contain no implementation, so every method returns None.
    """

    def __init__(
        self,
        model,
        data,
        feature_names=None,
        feature_types=None,
        model_output='raw',
        **kwargs
    ):
        """
        SHAP tree explainer for tree-based models.

        Parameters:
            model: Tree-based model (scikit-learn trees, XGBoost, LightGBM, CatBoost)
            data (array-like): Background data for computing expectations
            feature_names (list, optional): Names for features
            feature_types (list, optional): Types for features
            model_output (str): Type of model output ('raw', 'probability', 'log_loss')
            **kwargs: Additional arguments for TreeExplainer
        """

    def explain_local(self, X, y=None, name=None, interactions=False, **kwargs):
        """
        Generate SHAP explanations for tree model predictions.

        Parameters:
            X (array-like): Instances to explain
            y (array-like, optional): True labels
            name (str, optional): Name for explanation
            interactions (bool): Whether to compute interaction values
            **kwargs: Additional arguments

        Returns:
            Explanation object with SHAP values optimized for trees
        """

    def explain_global(self, name=None, max_display=20):
        """
        Generate global SHAP summary for tree model.

        Parameters:
            name (str, optional): Name for explanation
            max_display (int): Maximum features to display

        Returns:
            Global explanation with feature importance rankings
        """
"""Direct interpretation of tree model decisions by tracing prediction paths and computing feature contributions at each decision node.
class TreeInterpreter:
    """Tree interpreter for decision path analysis (interpret greybox API).

    Explains predictions by tracing the decision path through the tree and
    attributing a contribution to each feature at every split node.

    NOTE(review): this source shows documentation stubs only — the method
    bodies contain no implementation, so every method returns None.
    """

    def __init__(
        self,
        model,
        feature_names=None,
        **kwargs
    ):
        """
        Tree interpreter for decision path analysis.

        Parameters:
            model: Tree-based model (scikit-learn trees, random forests)
            feature_names (list, optional): Names for features
            **kwargs: Additional arguments
        """

    def explain_local(self, X, y=None, name=None, **kwargs):
        """
        Explain predictions by analyzing decision paths.

        Parameters:
            X (array-like): Instances to explain
            y (array-like, optional): True labels
            name (str, optional): Name for explanation
            **kwargs: Additional arguments

        Returns:
            Explanation object with decision path contributions
        """

    def explain_global(self, name=None):
        """
        Generate global tree structure analysis.

        Parameters:
            name (str, optional): Name for explanation

        Returns:
            Global explanation with tree structure insights
        """
"""from interpret.greybox import ShapTree
from interpret import show
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
# Load data and train model
data = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
data.data, data.target, test_size=0.2, random_state=42
)
rf = RandomForestClassifier(n_estimators=100, random_state=42)
rf.fit(X_train, y_train)
# Create SHAP tree explainer
shap_tree = ShapTree(
model=rf,
data=X_train,
feature_names=data.feature_names,
model_output='probability'
)
# Get local explanations (much faster than kernel SHAP)
local_exp = shap_tree.explain_local(X_test[:10], name="SHAP Tree Local")
show(local_exp)
# Get global explanations
global_exp = shap_tree.explain_global(name="SHAP Tree Global")
show(global_exp)# Compute interaction effects for tree models
interaction_exp = shap_tree.explain_local(
X_test[:5],
interactions=True,
name="SHAP Interactions"
)
show(interaction_exp)from interpret.greybox import TreeInterpreter
# Example: decision-path analysis with TreeInterpreter.
from sklearn.tree import DecisionTreeClassifier

# Train single decision tree for clearer interpretation
tree = DecisionTreeClassifier(max_depth=5, random_state=42)
tree.fit(X_train, y_train)

# Create tree interpreter
tree_interp = TreeInterpreter(
    model=tree,
    feature_names=data.feature_names
)

# Analyze decision paths
path_exp = tree_interp.explain_local(X_test[:5], name="Decision Paths")
show(path_exp)

# Global tree analysis
tree_global = tree_interp.explain_global(name="Tree Structure")
show(tree_global)

from sklearn.ensemble import GradientBoostingClassifier
# Example: ShapTree also works with gradient boosting models.
# Train gradient boosting model
gbm = GradientBoostingClassifier(n_estimators=100, random_state=42)
gbm.fit(X_train, y_train)

# SHAP tree explainer works with gradient boosting too
shap_gbm = ShapTree(
    model=gbm,
    data=X_train[:200],  # Sample background data
    feature_names=data.feature_names
)
gbm_exp = shap_gbm.explain_local(X_test[:5], name="GBM SHAP")
show(gbm_exp)

import xgboost as xgb
# Example: native XGBoost support.
# Train XGBoost model
xgb_model = xgb.XGBClassifier(n_estimators=100, random_state=42)
xgb_model.fit(X_train, y_train)

# SHAP tree explainer supports XGBoost natively
shap_xgb = ShapTree(
    model=xgb_model,
    data=X_train[:200],  # sampled background data keeps explanation cost down
    feature_names=data.feature_names
)
xgb_exp = shap_xgb.explain_local(X_test[:5], name="XGBoost SHAP")
show(xgb_exp)

# Global analysis
xgb_global = shap_xgb.explain_global(name="XGBoost Global")
show(xgb_global)

import time
# Example: wall-clock comparison of Tree SHAP vs model-agnostic Kernel SHAP.
from interpret.blackbox import ShapKernel

# Compare performance: Tree SHAP vs Kernel SHAP
instances = X_test[:10]

# Tree SHAP (optimized)
start_time = time.time()
tree_exp = shap_tree.explain_local(instances, name="Tree SHAP")
tree_time = time.time() - start_time

# Kernel SHAP (model-agnostic)
kernel_shap = ShapKernel(
    predict_fn=rf.predict_proba,
    data=X_train[:100],
    feature_names=data.feature_names
)
start_time = time.time()
kernel_exp = kernel_shap.explain_local(instances, name="Kernel SHAP")
kernel_time = time.time() - start_time

print(f"Tree SHAP time: {tree_time:.2f}s")
print(f"Kernel SHAP time: {kernel_time:.2f}s")
print(f"Speedup: {kernel_time/tree_time:.1f}x")

# Show both results
show(tree_exp)
show(kernel_exp)

# SHAP Tree Explainer supports:
- scikit-learn: DecisionTreeClassifier, DecisionTreeRegressor, RandomForestClassifier, RandomForestRegressor, ExtraTreesClassifier, ExtraTreesRegressor, GradientBoostingClassifier, GradientBoostingRegressor
- XGBoost: XGBClassifier, XGBRegressor, XGBRanker
- LightGBM: LGBMClassifier, LGBMRegressor, LGBMRanker
- CatBoost: CatBoostClassifier, CatBoostRegressor

Tree Interpreter supports:
# Usage tips for ShapTree with different model families.
# For XGBoost, ensure proper objective
xgb_model = xgb.XGBClassifier(objective='binary:logistic')

# For multi-class, SHAP returns values for each class
# Use model_output='probability' for probability outputs
shap_multi = ShapTree(model, data, model_output='probability')

# For regression models
shap_reg = ShapTree(regression_model, data, model_output='raw')

# Install with Tessl CLI:
npx tessl i tessl/pypi-interpret