0
# Interpretable Models (Glassbox)
1
2
Machine learning models that are interpretable by design. They provide transparency and explainability without post-hoc explanation methods, which makes them well suited to high-stakes applications where understanding model behavior is critical.
3
4
## Capabilities
5
6
### Explainable Boosting Machine (EBM)
7
8
A state-of-the-art interpretable model trained with gradient boosting under intelligibility constraints. EBM achieves accuracy competitive with black-box models while remaining fully interpretable, because each prediction is a sum of additive per-feature (and optional pairwise-interaction) contributions.
9
10
```python { .api }
11
class ExplainableBoostingClassifier:
    """Explainable Boosting Machine (EBM) for classification.

    Predictions are built from additive per-feature (and optional
    pairwise-interaction) contributions, which ``explain_global`` and
    ``explain_local`` expose directly.
    """

    def __init__(
        self,
        feature_names=None,
        feature_types=None,
        max_bins=1024,
        max_interaction_bins=64,
        interactions="3x",
        exclude=None,
        validation_size=0.15,
        outer_bags=14,
        inner_bags=0,
        learning_rate=0.015,
        greedy_ratio=10.0,
        cyclic_progress=False,
        smoothing_rounds=75,
        interaction_smoothing_rounds=75,
        max_rounds=50000,
        early_stopping_rounds=100,
        early_stopping_tolerance=1e-5,
        callback=None,
        min_samples_leaf=4,
        min_hessian=1e-4,
        reg_alpha=0.0,
        reg_lambda=0.0,
        max_delta_step=0.0,
        gain_scale=5.0,
        min_cat_samples=10,
        cat_smooth=10.0,
        missing="separate",
        max_leaves=2,
        monotone_constraints=None,
        objective="log_loss",
        n_jobs=-2,
        random_state=42,
    ):
        """
        Explainable Boosting Machine classifier.

        The defaults below are the classification defaults. In particular,
        ``objective`` defaults to "log_loss": a regression loss such as
        "rmse" is not a meaningful default for a classifier.

        Parameters:
            feature_names (list, optional): Names for features
            feature_types (list, optional): Types for features ('continuous', 'ordinal', 'nominal')
            max_bins (int): Maximum bins for continuous features
            max_interaction_bins (int): Maximum bins for interaction features
            interactions (int, float, str, or list): Number of interactions to detect, or specific
                pairs ("3x" = 3 times the number of features)
            exclude (list, optional): Features to exclude from interactions
            validation_size (float): Proportion for validation set
            outer_bags (int): Number of outer bags for training
            inner_bags (int): Number of inner bags for training
            learning_rate (float): Learning rate for boosting
            greedy_ratio (float): Ratio for greedy vs cyclic boosting
            cyclic_progress (bool, float, int): Use cyclic boosting progress
            smoothing_rounds (int): Rounds of smoothing
            interaction_smoothing_rounds (int): Rounds for interaction smoothing
            max_rounds (int): Maximum boosting rounds
            early_stopping_rounds (int): Early stopping patience
            early_stopping_tolerance (float): Early stopping tolerance
            callback (callable, optional): Callback function for training progress
            min_samples_leaf (int): Minimum samples per leaf
            min_hessian (float): Minimum hessian for node splitting
            reg_alpha (float): L1 regularization term
            reg_lambda (float): L2 regularization term
            max_delta_step (float): Maximum delta step for weight estimation
            gain_scale (float): Gain scaling factor
            min_cat_samples (int): Minimum samples for categorical splitting
            cat_smooth (float): Smoothing factor for categorical features
            missing (str): How to handle missing values ("separate", "none")
            max_leaves (int): Maximum leaves per tree
            monotone_constraints (list, optional): Monotonic constraints for features
            objective (str): Loss function to optimize (default "log_loss" for classification)
            n_jobs (int): Parallel jobs (-2 for all but one CPU)
            random_state (int, optional): Random seed
        """

    def fit(self, X, y, sample_weight=None):
        """Fit the EBM classifier to features ``X`` and class labels ``y``.

        ``sample_weight`` optionally weights individual training samples.
        """

    def predict(self, X):
        """Predict class labels for the rows of ``X``."""

    def predict_proba(self, X):
        """Predict class probabilities for the rows of ``X``."""

    def explain_global(self, name=None):
        """Get a global (whole-model) feature importance explanation."""

    def explain_local(self, X, y=None, name=None):
        """Get per-instance explanations for the rows of ``X``.

        ``y`` optionally supplies the matching labels; ``name`` labels the
        resulting explanation.
        """
99
100
class ExplainableBoostingRegressor:
    """Explainable Boosting Machine (EBM) for regression.

    Predictions are built from additive per-feature (and optional
    pairwise-interaction) contributions, which ``explain_global`` and
    ``explain_local`` expose directly.
    """

    def __init__(
        self,
        feature_names=None,
        feature_types=None,
        max_bins=1024,
        max_interaction_bins=64,
        interactions="5x",
        exclude=None,
        validation_size=0.15,
        outer_bags=14,
        inner_bags=0,
        learning_rate=0.04,
        greedy_ratio=10.0,
        cyclic_progress=False,
        smoothing_rounds=500,
        interaction_smoothing_rounds=100,
        max_rounds=50000,
        early_stopping_rounds=100,
        early_stopping_tolerance=1e-5,
        callback=None,
        min_samples_leaf=4,
        min_hessian=0.0,
        reg_alpha=0.0,
        reg_lambda=0.0,
        max_delta_step=0.0,
        gain_scale=5.0,
        min_cat_samples=10,
        cat_smooth=10.0,
        missing="separate",
        max_leaves=2,
        monotone_constraints=None,
        objective="rmse",
        n_jobs=-2,
        random_state=42,
    ):
        """
        Explainable Boosting Machine regressor.

        The defaults below are the regression defaults. In particular,
        ``objective`` defaults to "rmse": a classification loss such as
        "log_loss" is not a meaningful default for a regressor.

        Parameters:
            feature_names (list, optional): Names for features
            feature_types (list, optional): Types for features ('continuous', 'ordinal', 'nominal')
            max_bins (int): Maximum bins for continuous features
            max_interaction_bins (int): Maximum bins for interaction features
            interactions (int, float, str, or list): Number of interactions to detect, or specific
                pairs ("5x" = 5 times the number of features)
            exclude (list, optional): Features to exclude from interactions
            validation_size (float): Proportion for validation set
            outer_bags (int): Number of outer bags for training
            inner_bags (int): Number of inner bags for training
            learning_rate (float): Learning rate for boosting
            greedy_ratio (float): Ratio for greedy vs cyclic boosting
            cyclic_progress (bool, float, int): Use cyclic boosting progress
            smoothing_rounds (int): Rounds of smoothing
            interaction_smoothing_rounds (int): Rounds for interaction smoothing
            max_rounds (int): Maximum boosting rounds
            early_stopping_rounds (int): Early stopping patience
            early_stopping_tolerance (float): Early stopping tolerance
            callback (callable, optional): Callback function for training progress
            min_samples_leaf (int): Minimum samples per leaf
            min_hessian (float): Minimum hessian for node splitting
            reg_alpha (float): L1 regularization term
            reg_lambda (float): L2 regularization term
            max_delta_step (float): Maximum delta step for weight estimation
            gain_scale (float): Gain scaling factor
            min_cat_samples (int): Minimum samples for categorical splitting
            cat_smooth (float): Smoothing factor for categorical features
            missing (str): How to handle missing values ("separate", "none")
            max_leaves (int): Maximum leaves per tree
            monotone_constraints (list, optional): Monotonic constraints for features
            objective (str): Loss function to optimize (default "rmse" for regression)
            n_jobs (int): Parallel jobs (-2 for all but one CPU)
            random_state (int, optional): Random seed
        """

    def fit(self, X, y, sample_weight=None):
        """Fit the EBM regressor to features ``X`` and numeric targets ``y``.

        ``sample_weight`` optionally weights individual training samples.
        """

    def predict(self, X):
        """Predict numeric targets for the rows of ``X``."""

    def explain_global(self, name=None):
        """Get a global (whole-model) feature importance explanation."""

    def explain_local(self, X, y=None, name=None):
        """Get per-instance explanations for the rows of ``X``.

        ``y`` optionally supplies the matching targets; ``name`` labels the
        resulting explanation.
        """
185
186
def merge_ebms(ebms):
    """
    Combine several already-trained EBM models into one EBM.

    Intended for EBMs that were each fit on a different dataset.

    Parameters:
        ebms (list): The trained EBM models to combine

    Returns:
        A single merged EBM model
    """
196
```
197
198
### Linear Models
199
200
Traditional linear regression and logistic regression models with full interpretability through coefficient inspection.
201
202
```python { .api }
203
class LinearRegression:
    """Linear regression whose fitted coefficients serve as the explanation."""

    def __init__(self, feature_names=None, feature_types=None):
        """
        Linear regression model.

        Parameters:
            feature_names (list, optional): Names for features
            feature_types (list, optional): Types for features
        """

    def fit(self, X, y):
        """Fit the linear regression model to features ``X`` and targets ``y``."""

    def predict(self, X):
        """Predict numeric targets for the rows of ``X``."""

    def explain_global(self, name=None):
        """Get a global explanation built from the model coefficients."""

    def explain_local(self, X, y=None, name=None):
        """Get per-instance explanations for the rows of ``X``.

        ``y`` optionally supplies the matching targets; ``name`` labels the
        resulting explanation.
        """
221
222
class LogisticRegression:
    """Logistic regression whose fitted coefficients serve as the explanation."""

    def __init__(self, feature_names=None, feature_types=None):
        """
        Logistic regression model.

        Parameters:
            feature_names (list, optional): Names for features
            feature_types (list, optional): Types for features
        """

    def fit(self, X, y):
        """Fit the logistic regression model to features ``X`` and labels ``y``."""

    def predict(self, X):
        """Predict class labels for the rows of ``X``."""

    def predict_proba(self, X):
        """Predict class probabilities for the rows of ``X``."""

    def explain_global(self, name=None):
        """Get a global explanation built from the model coefficients."""

    def explain_local(self, X, y=None, name=None):
        """Get per-instance explanations for the rows of ``X``.

        ``y`` optionally supplies the matching labels; ``name`` labels the
        resulting explanation.
        """
243
```
244
245
### Decision Trees
246
247
Interpretable decision tree models for classification and regression with built-in explanation capabilities.
248
249
```python { .api }
250
class ClassificationTree:
    """Decision tree classifier whose learned splits are the explanation."""

    def __init__(
        self,
        feature_names=None,
        feature_types=None,
        max_depth=None,
        min_samples_split=2,
        min_samples_leaf=1,
        random_state=None,
    ):
        """
        Build a decision tree classifier.

        Parameters:
            feature_names (list, optional): Feature names
            feature_types (list, optional): Feature types
            max_depth (int, optional): Depth limit for the tree
            min_samples_split (int): Fewest samples an internal node needs before it may split
            min_samples_leaf (int): Fewest samples allowed in a leaf
            random_state (int, optional): Seed for reproducibility
        """

    def fit(self, X, y):
        """Train the tree on features ``X`` and class labels ``y``."""

    def predict(self, X):
        """Return predicted class labels for ``X``."""

    def predict_proba(self, X):
        """Return predicted class probabilities for ``X``."""

    def explain_global(self, name=None):
        """Return an explanation of the overall tree structure."""

    def explain_local(self, X, y=None, name=None):
        """Return decision-path explanations for the rows of ``X``."""
286
287
class RegressionTree:
    """Decision tree regressor whose learned splits are the explanation."""

    def __init__(
        self,
        feature_names=None,
        feature_types=None,
        max_depth=None,
        min_samples_split=2,
        min_samples_leaf=1,
        random_state=None,
    ):
        """
        Build a decision tree regressor.

        Parameters:
            feature_names (list, optional): Feature names
            feature_types (list, optional): Feature types
            max_depth (int, optional): Depth limit for the tree
            min_samples_split (int): Fewest samples an internal node needs before it may split
            min_samples_leaf (int): Fewest samples allowed in a leaf
            random_state (int, optional): Seed for reproducibility
        """

    def fit(self, X, y):
        """Train the tree on features ``X`` and numeric targets ``y``."""

    def predict(self, X):
        """Return predicted numeric targets for ``X``."""

    def explain_global(self, name=None):
        """Return an explanation of the overall tree structure."""

    def explain_local(self, X, y=None, name=None):
        """Return decision-path explanations for the rows of ``X``."""
314
```
315
316
### Decision Lists
317
318
Interpretable rule-based classification models that provide easy-to-understand if-then rules.
319
320
```python { .api }
321
class DecisionListClassifier:
    """Rule-based classifier that predicts through an ordered list of if-then rules."""

    def __init__(
        self,
        feature_names=None,
        feature_types=None,
        max_depth=5,
        min_samples_leaf=10,
        random_state=None,
    ):
        """
        Build a decision list classifier.

        Parameters:
            feature_names (list, optional): Feature names
            feature_types (list, optional): Feature types
            max_depth (int): Depth limit for each rule
            min_samples_leaf (int): Fewest samples a rule must cover
            random_state (int, optional): Seed for reproducibility
        """

    def fit(self, X, y):
        """Learn the rule list from features ``X`` and class labels ``y``."""

    def predict(self, X):
        """Return predicted class labels for ``X``."""

    def predict_proba(self, X):
        """Return predicted class probabilities for ``X``."""

    def explain_global(self, name=None):
        """Return an explanation listing the learned rules."""

    def explain_local(self, X, y=None, name=None):
        """Return the rule-based explanation for each row of ``X``."""
355
```
356
357
### APLR Models
358
359
Automatic Piecewise Linear Regression (APLR) models, which combine interpretability with the ability to capture non-linear relationships.
360
361
```python { .api }
362
class APLRClassifier:
    """Automatic Piecewise Linear Regression (APLR) classifier."""

    def __init__(
        self,
        feature_names=None,
        feature_types=None,
        random_state=None,
    ):
        """
        Automatic Piecewise Linear Regression (APLR) classifier.

        Parameters:
            feature_names (list, optional): Names for features
            feature_types (list, optional): Types for features
            random_state (int, optional): Random seed
        """

    def fit(self, X, y):
        """Fit the APLR classifier to features ``X`` and class labels ``y``."""

    def predict(self, X):
        """Predict class labels for the rows of ``X``."""

    def predict_proba(self, X):
        """Predict class probabilities for the rows of ``X``."""

    def explain_global(self, name=None):
        """Get a global explanation of the fitted piecewise linear terms."""

    def explain_local(self, X, y=None, name=None):
        """Get per-instance explanations for the rows of ``X``."""
392
393
class APLRRegressor:
    """Automatic Piecewise Linear Regression (APLR) regressor."""

    def __init__(
        self,
        feature_names=None,
        feature_types=None,
        random_state=None,
    ):
        """
        Automatic Piecewise Linear Regression (APLR) regressor.

        Parameters:
            feature_names (list, optional): Names for features
            feature_types (list, optional): Types for features
            random_state (int, optional): Random seed
        """

    def fit(self, X, y):
        """Fit the APLR regressor to features ``X`` and numeric targets ``y``."""

    def predict(self, X):
        """Predict numeric targets for the rows of ``X``."""

    def explain_global(self, name=None):
        """Get a global explanation of the fitted piecewise linear terms."""

    def explain_local(self, X, y=None, name=None):
        """Get per-instance explanations for the rows of ``X``."""
417
```
418
419
## Usage Examples
420
421
### Training an EBM Model
422
423
```python
424
from interpret.glassbox import ExplainableBoostingClassifier
from interpret import show
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

# Load the breast-cancer dataset and hold out 20% of the rows for testing.
data = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    data.data,
    data.target,
    test_size=0.2,
    random_state=42,
)

# Fit an EBM that also searches for 5 pairwise interaction terms.
ebm = ExplainableBoostingClassifier(
    feature_names=data.feature_names,
    interactions=5,
    random_state=42,
)
ebm.fit(X_train, y_train)

# Global view: how each term contributes across the whole model.
global_exp = ebm.explain_global()
show(global_exp)

# Local view: per-feature contributions for the first five test rows,
# shown alongside their true labels.
n_local = 5
local_exp = ebm.explain_local(X_test[:n_local], y_test[:n_local])
show(local_exp)
449
```
450
451
### Comparing Multiple Interpretable Models
452
453
```python
454
# Continues the EBM training example above: X_train, X_test, y_train and
# y_test come from that train/test split. `show` is imported here so this
# snippet does not rely on an import made in a different code block.
from interpret import show
from interpret.glassbox import (
    ExplainableBoostingClassifier,
    LogisticRegression,
    ClassificationTree,
)
from sklearn.metrics import accuracy_score

models = {
    "EBM": ExplainableBoostingClassifier(random_state=42),
    "Logistic": LogisticRegression(),
    "Tree": ClassificationTree(max_depth=5, random_state=42),
}

for name, model in models.items():
    model.fit(X_train, y_train)
    pred = model.predict(X_test)
    acc = accuracy_score(y_test, pred)
    print(f"{name} Accuracy: {acc:.4f}")

    # Show global explanations
    show(model.explain_global(name=f"{name} Global"))
475
```