# Callbacks

XGBoost provides a callback system for monitoring and controlling the training process. Callbacks allow custom logic to be executed at different stages of training, including early stopping, learning rate scheduling, evaluation monitoring, and model checkpointing.

## Capabilities

### Base Callback Class

Abstract base class for creating custom training callbacks.

```python { .api }
class TrainingCallback:
    def before_training(self, model):
        """
        Called before training starts.

        Parameters:
        - model: The model instance

        Returns:
        model: The model instance (possibly modified)
        """

    def after_training(self, model):
        """
        Called after training completes.

        Parameters:
        - model: The trained model instance

        Returns:
        model: The model instance (possibly modified)
        """

    def before_iteration(self, model, epoch, evals_log):
        """
        Called before each training iteration.

        Parameters:
        - model: Current model instance
        - epoch: Current epoch number
        - evals_log: Evaluation results log

        Returns:
        bool: False to continue training, True to stop training
        """

    def after_iteration(self, model, epoch, evals_log):
        """
        Called after each training iteration.

        Parameters:
        - model: Current model instance
        - epoch: Current epoch number
        - evals_log: Evaluation results log

        Returns:
        bool: False to continue training, True to stop training
        """
```

### Callback Container

Container class for managing multiple callbacks during training.

```python { .api }
class CallbackContainer:
    def __init__(self, callbacks, metric=None, is_maximize=False):
        """
        Container for managing training callbacks.

        Parameters:
        - callbacks: List of TrainingCallback objects
        - metric: Primary evaluation metric name
        - is_maximize: Whether to maximize the metric (True) or minimize it (False)
        """

    def before_training(self, model):
        """Execute before_training for all callbacks."""

    def after_training(self, model):
        """Execute after_training for all callbacks."""

    def before_iteration(self, model, epoch, evals_log):
        """Execute before_iteration for all callbacks."""

    def after_iteration(self, model, epoch, evals_log):
        """Execute after_iteration for all callbacks."""
```
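
The container gives the training loop a single dispatch point for the whole callback lifecycle: `before_training` once, then the per-iteration hooks, then `after_training`. The sketch below illustrates that dispatch order in plain Python; it is not XGBoost's actual implementation, and `MiniContainer`, `StopAtEpoch`, and `toy_train` are hypothetical names introduced only for this illustration.

```python
# Illustrative sketch of the callback dispatch order (not XGBoost's
# internal code). Any after_iteration returning True stops training.

class MiniContainer:
    def __init__(self, callbacks):
        self.callbacks = callbacks

    def before_training(self, model):
        for cb in self.callbacks:
            model = cb.before_training(model)
        return model

    def after_training(self, model):
        for cb in self.callbacks:
            model = cb.after_training(model)
        return model

    def after_iteration(self, model, epoch, evals_log):
        # True from any callback means "stop training"
        return any(cb.after_iteration(model, epoch, evals_log)
                   for cb in self.callbacks)


class StopAtEpoch:
    """Toy callback: request a stop once a given epoch is reached."""
    def __init__(self, stop_epoch):
        self.stop_epoch = stop_epoch

    def before_training(self, model):
        return model

    def after_training(self, model):
        return model

    def after_iteration(self, model, epoch, evals_log):
        return epoch >= self.stop_epoch


def toy_train(num_rounds, container):
    model = object()  # stand-in for the Booster
    model = container.before_training(model)
    completed = 0
    for epoch in range(num_rounds):
        # ... one boosting round would happen here ...
        completed += 1
        if container.after_iteration(model, epoch, {}):
            break
    container.after_training(model)
    return completed

rounds_run = toy_train(100, MiniContainer([StopAtEpoch(4)]))
print(rounds_run)  # -> 5 (epochs 0..4 ran, then the callback stopped training)
```

The real `xgb.train` loop follows the same shape, with `before_iteration` also consulted and the evaluation results accumulated into `evals_log`.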

### Early Stopping

Callback that stops training when the evaluation metric stops improving.

```python { .api }
class EarlyStopping(TrainingCallback):
    def __init__(
        self,
        rounds,
        metric_name=None,
        data_name=None,
        maximize=False,
        save_best=False,
        min_delta=0.0
    ):
        """
        Early stopping callback.

        Parameters:
        - rounds: Number of rounds to wait for improvement before stopping
        - metric_name: Name of the metric to monitor
        - data_name: Name of the dataset to monitor
        - maximize: Whether to maximize the metric (True) or minimize it (False)
        - save_best: Whether to save the best model
        - min_delta: Minimum change to qualify as an improvement
        """
```
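
The interaction of `rounds` and `min_delta` can be shown with a small stand-alone sketch of the bookkeeping (this is not XGBoost's internal code, and `stopping_epoch` is a hypothetical helper): a score only resets the patience counter when it beats the best score seen so far by more than `min_delta`.

```python
# Illustrative sketch of early-stopping bookkeeping (not XGBoost's
# internal code). Training stops once `rounds` consecutive epochs
# pass without the monitored score improving by more than min_delta.

def stopping_epoch(scores, rounds, min_delta=0.0, maximize=False):
    best = None
    stagnant = 0
    for epoch, score in enumerate(scores):
        if maximize:
            improved = best is None or score > best + min_delta
        else:
            improved = best is None or score < best - min_delta
        if improved:
            best = score
            stagnant = 0
        else:
            stagnant += 1
            if stagnant >= rounds:
                return epoch  # stop at this epoch
    return None  # early stopping never triggered

# RMSE-style (minimized) scores: gains after epoch 2 are below min_delta
scores = [0.50, 0.40, 0.35, 0.349, 0.348, 0.351, 0.352]
print(stopping_epoch(scores, rounds=3, min_delta=0.01))  # -> 5
```

Note that with `min_delta=0.01` the small improvements at epochs 3 and 4 do not count, so the counter runs out three epochs after the last qualifying improvement.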

### Learning Rate Scheduler

Callback for scheduling learning rate changes during training.

```python { .api }
class LearningRateScheduler(TrainingCallback):
    def __init__(self, learning_rates):
        """
        Learning rate scheduler callback.

        Parameters:
        - learning_rates: Sequence of learning rates (one per boosting round),
          or a callable that takes the epoch number and returns a learning rate
        """
```

### Evaluation Monitor

Callback for monitoring and logging evaluation metrics during training.

```python { .api }
class EvaluationMonitor(TrainingCallback):
    def __init__(
        self,
        rank=0,
        period=1,
        show_stdv=True
    ):
        """
        Evaluation monitoring callback.

        Parameters:
        - rank: Process rank for distributed training
        - period: Frequency of logging (every N epochs)
        - show_stdv: Whether to show standard deviation in CV
        """
```

### Training Checkpoint

Callback for saving model checkpoints during training.

```python { .api }
class TrainingCheckPoint(TrainingCallback):
    def __init__(
        self,
        directory,
        name="model",
        as_pickle=False,
        interval=1
    ):
        """
        Training checkpoint callback.

        Parameters:
        - directory: Directory in which to save checkpoints
        - name: Base name for checkpoint files
        - as_pickle: Whether to save as a pickle (True) or in the XGBoost format (False)
        - interval: Checkpoint interval (every N epochs)
        """
```

## Usage Examples

### Basic Early Stopping

```python
import xgboost as xgb
from xgboost.callback import EarlyStopping

# Create the early stopping callback
early_stop = EarlyStopping(
    rounds=10,
    metric_name='rmse',
    data_name='eval',
    maximize=False,
    save_best=True
)

# Train with early stopping
dtrain = xgb.DMatrix(X_train, label=y_train)
deval = xgb.DMatrix(X_eval, label=y_eval)

model = xgb.train(
    params,
    dtrain,
    num_boost_round=1000,
    evals=[(deval, 'eval')],
    callbacks=[early_stop],
    verbose_eval=False
)
```

### Learning Rate Scheduling

```python
from xgboost.callback import LearningRateScheduler

# Define a piecewise learning rate schedule
def lr_schedule(epoch):
    if epoch < 50:
        return 0.1
    elif epoch < 100:
        return 0.05
    else:
        return 0.01

# Create the scheduler callback
lr_scheduler = LearningRateScheduler(lr_schedule)

# Train with learning rate scheduling
model = xgb.train(
    params,
    dtrain,
    num_boost_round=150,
    callbacks=[lr_scheduler]
)
```

### Multiple Callbacks

```python
from xgboost.callback import (
    EarlyStopping,
    EvaluationMonitor,
    TrainingCheckPoint
)

# Create multiple callbacks
callbacks = [
    EarlyStopping(rounds=10, save_best=True),
    EvaluationMonitor(period=10),
    TrainingCheckPoint(directory='./checkpoints', interval=50)
]

# Train with multiple callbacks
model = xgb.train(
    params,
    dtrain,
    num_boost_round=1000,
    evals=[(deval, 'eval')],
    callbacks=callbacks
)
```

### Custom Callback

```python
from xgboost.callback import TrainingCallback

class CustomLoggingCallback(TrainingCallback):
    def __init__(self, log_file):
        self.log_file = log_file

    def before_training(self, model):
        with open(self.log_file, 'w') as f:
            f.write("Training started\n")
        return model

    def after_iteration(self, model, epoch, evals_log):
        if evals_log and 'eval' in evals_log:
            metric_value = evals_log['eval']['rmse'][-1]
            with open(self.log_file, 'a') as f:
                f.write(f"Epoch {epoch}: RMSE = {metric_value}\n")
        return False  # Continue training

# Use the custom callback
custom_logger = CustomLoggingCallback('training.log')

model = xgb.train(
    params,
    dtrain,
    num_boost_round=100,
    evals=[(deval, 'eval')],
    callbacks=[custom_logger]
)
```

### Scikit-Learn Interface with Callbacks

```python
from xgboost import XGBRegressor
from xgboost.callback import EarlyStopping

# Create the callback
early_stop = EarlyStopping(rounds=10)

# Use with the sklearn interface. Pass either the callback or
# early_stopping_rounds=10 -- not both at once.
model = XGBRegressor(
    n_estimators=1000,
    callbacks=[early_stop]
)

model.fit(
    X_train, y_train,
    eval_set=[(X_eval, y_eval)],
    verbose=False
)
```