# Training Callbacks

Flexible training control through callback functions enabling early stopping, evaluation logging, parameter adjustment, and custom training behaviors. LightGBM's callback system supports both built-in callbacks for common use cases and custom callback implementations for specialized training requirements.

## Capabilities

### Early Stopping

Automatically stop training when the validation metric stops improving, preventing overfitting and saving computation time.

```python { .api }
def early_stopping(stopping_rounds, first_metric_only=False, verbose=True, min_delta=0.0):
    """
    Create early stopping callback for training.

    Training continues only while the validation score keeps improving by at
    least ``min_delta`` at least once every ``stopping_rounds`` rounds. With
    several validation sets or metrics, all of them are checked (the training
    data is always ignored); set ``first_metric_only=True`` to check only the
    first metric. At least one validation dataset and one metric are required.

    Parameters:
    - stopping_rounds: int - Number of rounds without improvement to trigger stopping
    - first_metric_only: bool - Whether to use only the first metric for early stopping
    - verbose: bool - Whether to print early stopping information
    - min_delta: float or list of float - Minimum change in the monitored quantity
      to qualify as an improvement. A single float applies to every metric; a
      list supplies one value per metric.

    Returns:
    - callable: Early stopping callback function for use in train() or cv()
    """
class EarlyStopException(Exception):
    """
    Exception raised for early stopping in training.

    This exception can be raised from custom callbacks to trigger early stopping
    with specific iteration and score information.
    """

    def __init__(self, best_iteration, best_score):
        """
        Create early stopping exception.

        Parameters:
        - best_iteration: int - Best iteration (0-based, i.e. ``env.iteration`` of
          the best round) when early stopping occurred
        - best_score: list - Evaluation results of the best round, in the same
          format as ``env.evaluation_result_list``
          (``(dataset_name, metric_name, value, is_higher_better)`` tuples)
        """
        super().__init__()
        self.best_iteration = best_iteration
        self.best_score = best_score
```

### Evaluation Logging

Control the frequency and format of evaluation metric logging during training.

```python { .api }
def log_evaluation(period=1, show_stdv=True):
    """
    Create evaluation logging callback for training progress monitoring.

    Parameters:
    - period: int - Evaluation logging frequency (log every N iterations)
    - show_stdv: bool - Whether to show the standard deviation of each metric
      across folds (only meaningful in cv())

    Returns:
    - callable: Logging callback function for use in train() or cv()
    """
```

### Evaluation Recording

Record evaluation results in a dictionary for later analysis and visualization.

```python { .api }
def record_evaluation(eval_result):
    """
    Create evaluation recording callback to store training history.

    Parameters:
    - eval_result: dict - Dictionary to store evaluation results
      Will be populated with structure:
      {
          'dataset_name': {
              'metric_name': [score1, score2, ...]
          }
      }
      Dataset names are those passed via ``valid_names``; when omitted,
      LightGBM uses 'training' for the training set and 'valid_0',
      'valid_1', ... for the other validation sets.

    Returns:
    - callable: Recording callback function for use in train() or cv()
    """
```

### Parameter Reset

Dynamically adjust training parameters (for example, learning-rate decay) during the training process.

```python { .api }
def reset_parameter(**kwargs):
    """
    Create parameter reset callback for dynamic parameter adjustment.

    The callback runs before each boosting round and applies the value
    computed for that round.

    Parameters:
    - **kwargs: Parameter names mapped to per-round values. Each value must be
      either
        - a list, where round ``i`` (0-based) uses ``value[i]``, so the list
          needs one entry per boosting round, or
        - a callable taking the current round index (0-based) and returning
          the parameter value for that round (e.g. learning-rate decay).
      Most training parameters can be adjusted (learning_rate, num_leaves,
      etc.); structural ones such as num_class, boosting and metric cannot be
      changed during training.

    Returns:
    - callable: Parameter reset callback function for use in train() or cv()
    """
```

## Usage Examples

### Early Stopping Example

```python
import lightgbm as lgb
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

# Load data
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Prepare datasets; validation data references the training Dataset so both
# share the same feature binning
train_data = lgb.Dataset(X_train, label=y_train)
test_data = lgb.Dataset(X_test, label=y_test, reference=train_data)

# Train with early stopping
model = lgb.train(
    {
        'objective': 'binary',
        'metric': 'binary_logloss',
        'num_leaves': 31,
        'learning_rate': 0.05,
        'verbose': -1
    },
    train_data,
    num_boost_round=200,
    valid_sets=[test_data],
    valid_names=['test'],
    callbacks=[
        lgb.early_stopping(stopping_rounds=20, verbose=True),
        lgb.log_evaluation(period=20)
    ]
)

# best_iteration is the best round (1-based); training itself ran up to
# stopping_rounds further rounds before giving up
print(f"Best iteration: {model.best_iteration}")
print(f"Rounds trained: {model.current_iteration()}")
print(f"Best validation score: {model.best_score['test']['binary_logloss']:.4f}")
```

### Comprehensive Logging Example

```python
148
import lightgbm as lgb
149
import matplotlib.pyplot as plt
150
from sklearn.datasets import load_diabetes
151
from sklearn.model_selection import train_test_split
152
153
# Load data
154
X, y = load_diabetes(return_X_y=True)
155
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
156
157
# Prepare datasets
158
train_data = lgb.Dataset(X_train, label=y_train)
159
test_data = lgb.Dataset(X_test, label=y_test, reference=train_data)
160
161
# Set up evaluation result recording
162
eval_result = {}
163
164
# Train with comprehensive logging
165
model = lgb.train(
166
{
167
'objective': 'regression',
168
'metric': ['rmse', 'mae', 'mape'],
169
'num_leaves': 31,
170
'learning_rate': 0.05,
171
'verbose': -1
172
},
173
train_data,
174
num_boost_round=150,
175
valid_sets=[train_data, test_data],
176
valid_names=['train', 'test'],
177
callbacks=[
178
lgb.record_evaluation(eval_result), # Record all metrics
179
lgb.log_evaluation(period=25, show_stdv=False), # Log every 25 iterations
180
lgb.early_stopping(stopping_rounds=15, first_metric_only=True)
181
]
182
)
183
184
# Analyze recorded results
185
print("Recorded metrics:")
186
for dataset in eval_result:
187
print(f" {dataset}:")
188
for metric in eval_result[dataset]:
189
final_score = eval_result[dataset][metric][-1]
190
print(f" {metric}: {final_score:.4f}")
191
192
# Plot training curves
193
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
194
195
for i, metric in enumerate(['rmse', 'mae', 'mape']):
196
ax = axes[i]
197
198
# Plot train and test curves
199
train_scores = eval_result['train'][metric]
200
test_scores = eval_result['test'][metric]
201
202
ax.plot(range(len(train_scores)), train_scores, label='Train', color='blue')
203
ax.plot(range(len(test_scores)), test_scores, label='Test', color='red')
204
205
# Mark best iteration
206
ax.axvline(x=model.best_iteration-1, color='green', linestyle='--',
207
label=f'Best ({model.best_iteration})')
208
209
ax.set_title(f'{metric.upper()} During Training')
210
ax.set_xlabel('Iteration')
211
ax.set_ylabel(metric.upper())
212
ax.legend()
213
ax.grid(True, alpha=0.3)
214
215
plt.tight_layout()
216
plt.show()
217
```

### Dynamic Parameter Adjustment Example

```python
222
import lightgbm as lgb
223
from sklearn.datasets import make_regression
224
225
# Generate data
226
X, y = make_regression(n_samples=10000, n_features=20, noise=0.1, random_state=42)
227
train_data = lgb.Dataset(X, label=y)
228
229
# Create learning rate scheduler
230
def learning_rate_scheduler(current_round, learning_rate_start=0.1, decay_rate=0.95, decay_step=20):
231
"""Custom learning rate scheduler."""
232
if current_round % decay_step == 0 and current_round > 0:
233
new_lr = learning_rate_start * (decay_rate ** (current_round // decay_step))
234
return {'learning_rate': new_lr}
235
return {}
236
237
# Train with dynamic parameter adjustment
238
eval_result = {}
239
model = lgb.train(
240
{
241
'objective': 'regression',
242
'metric': 'rmse',
243
'num_leaves': 31,
244
'learning_rate': 0.1, # Starting learning rate
245
'verbose': -1
246
},
247
train_data,
248
num_boost_round=100,
249
callbacks=[
250
lgb.record_evaluation(eval_result),
251
lgb.log_evaluation(period=20),
252
# Reset learning rate every 20 iterations
253
lgb.reset_parameter(learning_rate=lambda: 0.1 * (0.95 ** (model.current_iteration() // 20)))
254
]
255
)
256
257
print(f"Final RMSE: {eval_result['training']['rmse'][-1]:.4f}")
258
```

### Cross-Validation with Callbacks

```python
import lightgbm as lgb
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_wine

# Load data
X, y = load_wine(return_X_y=True)
train_data = lgb.Dataset(X, label=y)

# Perform cross-validation with callbacks
cv_results = lgb.cv(
    {
        'objective': 'multiclass',
        'num_class': 3,
        'metric': 'multi_logloss',
        'num_leaves': 31,
        'learning_rate': 0.05,
        'verbose': -1
    },
    train_data,
    num_boost_round=100,
    nfold=5,
    stratified=True,
    shuffle=True,
    seed=42,
    callbacks=[
        lgb.log_evaluation(period=20, show_stdv=True),  # Show std dev in CV
        lgb.early_stopping(stopping_rounds=10)
    ]
)

# With early stopping, cv() truncates the history at the best iteration, so
# its length is the best iteration and its last entry is the best score.
# ('valid <metric>-mean' is the key format used by LightGBM 4.x.)
means = np.asarray(cv_results['valid multi_logloss-mean'])
stds = np.asarray(cv_results['valid multi_logloss-stdv'])

print("CV Results:")
print(f"Best iteration: {len(means)}")
print(f"Best CV score: {means[-1]:.4f} ± {stds[-1]:.4f}")

# Plot CV results with a one-standard-deviation band
iterations = np.arange(len(means))

plt.figure(figsize=(10, 6))
plt.plot(iterations, means, color='blue', label='CV Mean')
plt.fill_between(iterations,
                 means - stds,
                 means + stds,
                 alpha=0.3, color='blue', label='CV Std Dev')
plt.xlabel('Iteration')
plt.ylabel('Multi Log Loss')
plt.title('Cross-Validation Results with Standard Deviation')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
```

### Custom Callback Implementation

```python
import lightgbm as lgb
from sklearn.datasets import load_diabetes


def custom_metric_tracker(metric_threshold=0.1):
    """
    Custom callback that reports the first iteration at which RMSE drops
    below ``metric_threshold``.
    """
    reached = False

    def callback(env):
        # env contains information about current training state
        # env.model: current model
        # env.params: training parameters
        # env.iteration: current iteration (0-based)
        # env.begin_iteration: beginning iteration
        # env.end_iteration: ending iteration
        # env.evaluation_result_list: current evaluation results, one tuple per
        #   (dataset, metric); only populated for datasets passed in valid_sets
        nonlocal reached
        if reached:
            return

        for eval_result in env.evaluation_result_list:
            # Index instead of unpacking: train() yields 4-tuples, cv() 5-tuples
            metric_name, metric_value = eval_result[1], eval_result[2]
            if metric_name == 'rmse' and metric_value < metric_threshold:
                # +1 so the number matches log_evaluation's 1-based "[N]" prefix
                print(f"🎯 Metric threshold reached! RMSE: {metric_value:.4f} at iteration {env.iteration + 1}")
                reached = True
                break
        # Callback return values are ignored by LightGBM; raise
        # lgb.EarlyStopException to stop training instead.

    return callback


def custom_progress_bar(total_rounds, bar_length=50):
    """
    Custom progress bar callback.
    """
    def callback(env):
        current = env.iteration - env.begin_iteration + 1
        progress = current / total_rounds
        filled_length = int(bar_length * progress)

        bar = '█' * filled_length + '-' * (bar_length - filled_length)
        percent = progress * 100

        # flush so the carriage-return redraw is visible immediately
        print(f'\rProgress: |{bar}| {percent:.1f}% ({current}/{total_rounds})', end='', flush=True)

        if current == total_rounds:
            print()  # New line when complete

    return callback


# Load data (load_boston was removed in scikit-learn 1.2)
X, y = load_diabetes(return_X_y=True)
train_data = lgb.Dataset(X, label=y)

num_rounds = 100

# Train with custom callbacks
model = lgb.train(
    {
        'objective': 'regression',
        'metric': 'rmse',
        'num_leaves': 31,
        'learning_rate': 0.05,
        'verbose': -1
    },
    train_data,
    num_boost_round=num_rounds,
    # Without valid_sets, env.evaluation_result_list is empty and neither the
    # tracker nor log_evaluation has anything to report
    valid_sets=[train_data],
    callbacks=[
        custom_progress_bar(num_rounds),  # Custom progress tracking
        # Diabetes targets span roughly 25-346, so pick a threshold on that scale
        custom_metric_tracker(50.0),  # Alert when training RMSE first drops below 50.0
        lgb.log_evaluation(period=25)  # Standard logging
    ]
)

print("\nTraining completed!")
# eval_train() returns (dataset_name, metric_name, value, is_higher_better) tuples
print(f"Final RMSE: {model.eval_train()[0][2]:.4f}")
```

### Callback with sklearn Interface

```python
import lightgbm as lgb
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Load data
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Set up evaluation tracking
eval_result = {}

# Use callbacks with sklearn interface
model = lgb.LGBMClassifier(
    objective='multiclass',
    n_estimators=100,
    learning_rate=0.05,
    num_leaves=31,
    random_state=42
)

# Fit with callbacks. LightGBM 4.0 removed fit()'s early_stopping_rounds and
# verbose arguments; the early_stopping and log_evaluation callbacks replace them.
model.fit(
    X_train, y_train,
    eval_set=[(X_train, y_train), (X_test, y_test)],
    eval_names=['train', 'test'],
    eval_metric='multi_logloss',
    callbacks=[
        lgb.early_stopping(stopping_rounds=15),  # the 'train' entry is never monitored
        lgb.log_evaluation(period=1),
        lgb.record_evaluation(eval_result)
    ]
)

# Access recorded results; best_iteration_ is 1-based, recorded lists are 0-based
print(f"Best iteration: {model.best_iteration_}")
print(f"Best test score: {eval_result['test']['multi_logloss'][model.best_iteration_ - 1]:.4f}")

# Make predictions
predictions = model.predict(X_test)
probabilities = model.predict_proba(X_test)

print(f"Test accuracy: {(predictions == y_test).mean():.4f}")
```

## Advanced Callback Patterns

### Conditional Early Stopping

```python
import lightgbm as lgb
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split


def conditional_early_stopping(stopping_rounds, condition_func):
    """
    Early stopping that only triggers when a condition is met.

    Monitors the first entry of ``env.evaluation_result_list``. Training stops
    once the score has not improved for at least ``stopping_rounds`` rounds AND
    ``condition_func(env)`` returns True. State lives in this closure, so create
    a fresh callback for every training run.
    """
    best_score = None
    best_iteration = 0
    best_results = []
    current_rounds = 0

    def callback(env):
        nonlocal best_score, best_iteration, best_results, current_rounds

        if not env.evaluation_result_list:
            return

        first_result = env.evaluation_result_list[0]
        # Index 2 is the value and index 3 the higher-is-better flag in both
        # train() and cv() result tuples
        current_score, higher_is_better = first_result[2], first_result[3]

        if best_score is None:
            improved = True
        elif higher_is_better:
            improved = current_score > best_score
        else:
            improved = current_score < best_score

        if improved:
            best_score = current_score
            best_iteration = env.iteration
            # Keep the best round's results: EarlyStopException.best_score must
            # describe best_iteration, not the round at which we stop
            best_results = list(env.evaluation_result_list)
            current_rounds = 0
        else:
            current_rounds += 1

        # Only stop if condition is met AND stopping rounds exceeded
        if condition_func(env) and current_rounds >= stopping_rounds:
            print(f"Conditional early stopping at iteration {env.iteration}")
            raise lgb.EarlyStopException(best_iteration, best_results)

    return callback


# Example usage
def stop_condition(env):
    """Stop only if we've trained for at least 50 iterations."""
    return env.iteration >= 50


# Prepare data
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
train_data = lgb.Dataset(X_train, label=y_train)
test_data = lgb.Dataset(X_test, label=y_test, reference=train_data)

params = {
    'objective': 'binary',
    'metric': 'binary_logloss',
    'learning_rate': 0.05,
    'verbose': -1
}

# Use conditional early stopping
model = lgb.train(
    params,
    train_data,
    num_boost_round=200,
    valid_sets=[test_data],
    callbacks=[
        conditional_early_stopping(10, stop_condition),
        lgb.log_evaluation(20)
    ]
)

print(f"Best iteration: {model.best_iteration}")
```

### Multi-Metric Monitoring

```python
import lightgbm as lgb
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split


def multi_metric_monitor(metrics_config):
    """
    Monitor multiple metrics with different thresholds and behaviors.

    Args:
        metrics_config: dict like {
            'rmse': {'threshold': 5.0, 'action': 'alert'},
            'mae': {'threshold': 3.0, 'action': 'stop'}
        }
        A metric reaches its threshold when it drops below it, or rises above
        it for metrics where higher is better. 'alert' prints once per
        (dataset, metric) pair; 'stop' ends training immediately.
    """
    alerted = set()

    def callback(env):
        for eval_result in env.evaluation_result_list:
            dataset_name, metric_name, metric_value, is_higher_better = eval_result[:4]

            config = metrics_config.get(metric_name)
            if config is None:
                continue

            threshold = config['threshold']
            if is_higher_better:
                reached = metric_value > threshold
            else:
                reached = metric_value < threshold
            if not reached:
                continue

            action = config['action']
            if action == 'alert':
                if (dataset_name, metric_name) not in alerted:
                    alerted.add((dataset_name, metric_name))
                    print(f"{metric_name} threshold reached: {metric_value:.4f}")
            elif action == 'stop':
                print(f"Stopping due to {metric_name}: {metric_value:.4f}")
                raise lgb.EarlyStopException(env.iteration, env.evaluation_result_list)

    return callback


# Prepare data
X, y = load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
train_data = lgb.Dataset(X_train, label=y_train)
test_data = lgb.Dataset(X_test, label=y_test, reference=train_data)

# Thresholds are in target units (diabetes targets span roughly 25-346);
# these values are illustrative - tune them for your data
metrics_config = {
    'rmse': {'threshold': 60.0, 'action': 'alert'},
    'mae': {'threshold': 45.0, 'action': 'stop'}
}

model = lgb.train(
    {
        'objective': 'regression',
        'metric': ['rmse', 'mae'],
        'verbose': -1
    },
    train_data,
    num_boost_round=200,
    valid_sets=[test_data],
    callbacks=[
        multi_metric_monitor(metrics_config),
        lgb.log_evaluation(25)
    ]
)
```