or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

core-training.mddistributed-computing.mdindex.mdsklearn-interface.mdtraining-callbacks.mdvisualization.md

training-callbacks.mddocs/

0

# Training Callbacks

1

2

Flexible training control through callback functions enabling early stopping, evaluation logging, parameter adjustment, and custom training behaviors. LightGBM's callback system supports both built-in callbacks for common use cases and custom callback implementations for specialized training requirements.

3

4

## Capabilities

5

6

### Early Stopping

7

8

Automatically stop training when the validation metric stops improving, which helps prevent overfitting and saves computation time.

9

10

```python { .api }

def early_stopping(stopping_rounds, first_metric_only=False, verbose=True, min_delta=0.0):
    """
    Create an early stopping callback for training.

    The callback monitors the validation metrics and stops training once a
    metric has not improved for ``stopping_rounds`` consecutive rounds. The
    training set, if it is also passed in ``valid_sets``, is ignored for early
    stopping. At least one validation dataset and metric are required;
    otherwise LightGBM raises ValueError at the first iteration.

    Parameters:
    - stopping_rounds: int - Number of consecutive rounds without improvement
      that triggers stopping
    - first_metric_only: bool - If True, only the first metric is checked;
      otherwise every metric is checked
    - verbose: bool - Whether to log early stopping progress and the best
      iteration
    - min_delta: float or list of float - Minimum change in the monitored
      metric that counts as an improvement; a list gives one value per metric

    Returns:
    - callable: Early stopping callback for use in train() or cv(). When it
      triggers, train() sets Booster.best_iteration and Booster.best_score.
    """


class EarlyStopException(Exception):
    """
    Exception that signals early stopping to train() and cv().

    Raise it from a custom callback to end training. The training loop catches
    it and records the iteration and scores it carries.
    """

    def __init__(self, best_iteration, best_score):
        """
        Create an early stopping exception.

        Parameters:
        - best_iteration: int - Zero-based index of the best iteration (the
          same numbering as env.iteration); train() stores
          best_iteration + 1 as Booster.best_iteration
        - best_score: list - Evaluation results at the best iteration, in the
          same tuple format as env.evaluation_result_list
        """
        super().__init__()
        self.best_iteration = best_iteration
        self.best_score = best_score

```

45

46

### Evaluation Logging

47

48

Control the frequency and format of evaluation metric logging during training.

49

50

```python { .api }

def log_evaluation(period=1, show_stdv=True):
    """
    Create a callback that logs evaluation results during training.

    Parameters:
    - period: int - Log every ``period`` iterations; a value <= 0 disables
      logging
    - show_stdv: bool - Whether to show the standard deviation of each metric;
      only has an effect in cv(), where results are aggregated across folds

    Returns:
    - callable: Logging callback for use in train() or cv(). Nothing is logged
      when there are no evaluation results (e.g. train() without valid_sets).
    """

```

63

64

### Evaluation Recording

65

66

Record evaluation results in a dictionary for later analysis and visualization.

67

68

```python { .api }

def record_evaluation(eval_result):
    """
    Create a callback that records the evaluation history in a dictionary.

    Parameters:
    - eval_result: dict - Dictionary to fill with evaluation results. It is
      cleared when training starts and then populated with the structure:
          {
              'dataset_name': {
                  'metric_name': [score1, score2, ...]
              }
          }
      In cv(), each metric is stored as two lists, 'metric_name-mean' and
      'metric_name-stdv'.

    Returns:
    - callable: Recording callback for use in train() or cv()

    Raises:
    - TypeError: (raised by LightGBM) if eval_result is not a dict
    """

```

86

87

### Parameter Reset

88

89

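Dynamically adjust training parameters during training. Each parameter is given either as a list with one value per boosting round or as a function that receives the current round index and returns the value.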

90

91

```python { .api }

def reset_parameter(**kwargs):
    """
    Create a callback that resets parameters before each boosting round.

    Parameters:
    - **kwargs: Parameter names mapped to their per-round values. Each value
      must be either:
        * a list with one entry per boosting round (its length must equal the
          number of rounds in this training run), or
        * a callable that takes the current round index (0-based, counted
          from the start of this training run) and returns the value.
      Any other value type is rejected with ValueError. Any LightGBM
      parameter can be reset (learning_rate, num_leaves, etc.).

    Returns:
    - callable: Parameter reset callback for use in train() or cv()
    """

```

104

105

## Usage Examples

106

107

### Early Stopping Example

108

109

```python

import lightgbm as lgb
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

# Load data
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Prepare datasets
train_data = lgb.Dataset(X_train, label=y_train)
test_data = lgb.Dataset(X_test, label=y_test, reference=train_data)

# Train with early stopping
model = lgb.train(
    {
        'objective': 'binary',
        'metric': 'binary_logloss',
        'num_leaves': 31,
        'learning_rate': 0.05,
        'verbose': -1,
    },
    train_data,
    num_boost_round=200,
    valid_sets=[test_data],
    valid_names=['test'],
    callbacks=[
        lgb.early_stopping(stopping_rounds=20, verbose=True),
        lgb.log_evaluation(period=20),
    ],
)

# best_iteration is the best round (1-based), not the round where training
# stopped: when early stopping triggers, training has already run
# stopping_rounds rounds past it.
print(f"Best iteration: {model.best_iteration}")
print(f"Best validation score: {model.best_score['test']['binary_logloss']:.4f}")

```

144

145

### Comprehensive Logging Example

146

147

```python

import lightgbm as lgb
import matplotlib.pyplot as plt
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split

# Metrics tracked during training; the same list drives the plots below
METRICS = ('rmse', 'mae', 'mape')

# Split the diabetes regression data into train / test sets
X, y = load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

train_data = lgb.Dataset(X_train, label=y_train)
test_data = lgb.Dataset(X_test, label=y_test, reference=train_data)

params = {
    'objective': 'regression',
    'metric': list(METRICS),
    'num_leaves': 31,
    'learning_rate': 0.05,
    'verbose': -1,
}

# record_evaluation fills this dict with the per-iteration history of every metric
eval_result = {}

model = lgb.train(
    params,
    train_data,
    num_boost_round=150,
    valid_sets=[train_data, test_data],
    valid_names=['train', 'test'],
    callbacks=[
        lgb.record_evaluation(eval_result),              # keep the full history
        lgb.log_evaluation(period=25, show_stdv=False),  # print every 25 rounds
        # first_metric_only=True: stop on the first metric in params (rmse)
        lgb.early_stopping(stopping_rounds=15, first_metric_only=True),
    ],
)

# Report the last recorded value of each metric per dataset
print("Recorded metrics:")
for dataset, history in eval_result.items():
    print(f" {dataset}:")
    for metric, scores in history.items():
        print(f" {metric}: {scores[-1]:.4f}")

# One subplot per metric: train vs. test curves plus the best iteration
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
best_x = model.best_iteration - 1  # best_iteration is 1-based, the curves are 0-indexed

for ax, metric in zip(axes, METRICS):
    train_curve = eval_result['train'][metric]
    test_curve = eval_result['test'][metric]

    ax.plot(range(len(train_curve)), train_curve, label='Train', color='blue')
    ax.plot(range(len(test_curve)), test_curve, label='Test', color='red')
    ax.axvline(x=best_x, color='green', linestyle='--',
               label=f'Best ({model.best_iteration})')

    ax.set_title(f'{metric.upper()} During Training')
    ax.set_xlabel('Iteration')
    ax.set_ylabel(metric.upper())
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

```

218

219

### Dynamic Parameter Adjustment Example

220

221

```python

import lightgbm as lgb
from sklearn.datasets import make_regression

# Generate data
X, y = make_regression(n_samples=10000, n_features=20, noise=0.1, random_state=42)
train_data = lgb.Dataset(X, label=y)


def learning_rate_scheduler(current_round, learning_rate_start=0.1, decay_rate=0.95, decay_step=20):
    """Step-decay schedule: scale the learning rate by decay_rate every decay_step rounds.

    reset_parameter() calls this with the 0-based round index and uses the
    returned float as the learning rate for that round.
    """
    return learning_rate_start * (decay_rate ** (current_round // decay_step))


# Train with dynamic parameter adjustment
eval_result = {}
model = lgb.train(
    {
        'objective': 'regression',
        'metric': 'rmse',
        'num_leaves': 31,
        'learning_rate': 0.1,  # Starting learning rate (round 0 of the schedule)
        'verbose': -1,
    },
    train_data,
    num_boost_round=100,
    # Evaluate on the training data so the callbacks have results to log and
    # record; without a name in valid_names it is recorded as 'training'.
    valid_sets=[train_data],
    callbacks=[
        lgb.record_evaluation(eval_result),
        lgb.log_evaluation(period=20),
        # Decay the learning rate by 5% every 20 rounds
        lgb.reset_parameter(learning_rate=learning_rate_scheduler),
    ],
)

print(f"Final RMSE: {eval_result['training']['rmse'][-1]:.4f}")

```

259

260

### Cross-Validation with Callbacks

261

262

```python

import lightgbm as lgb
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_wine

# Load data
X, y = load_wine(return_X_y=True)
train_data = lgb.Dataset(X, label=y)

# Perform cross-validation with callbacks
cv_results = lgb.cv(
    {
        'objective': 'multiclass',
        'num_class': 3,
        'metric': 'multi_logloss',
        'num_leaves': 31,
        'learning_rate': 0.05,
        'verbose': -1,
    },
    train_data,
    num_boost_round=100,
    nfold=5,
    stratified=True,
    shuffle=True,
    seed=42,
    callbacks=[
        lgb.log_evaluation(period=20, show_stdv=True),  # Show std dev in CV
        lgb.early_stopping(stopping_rounds=10),
    ],
)

# With early stopping, cv() truncates every result list at the best
# iteration, so the list length is the best iteration and the last entry is
# the best score.
means = cv_results['valid multi_logloss-mean']
stds = cv_results['valid multi_logloss-stdv']

print("CV Results:")
print(f"Best iteration: {len(means)}")
print(f"Best CV score: {means[-1]:.4f} ± {stds[-1]:.4f}")

# Plot CV results with a one-standard-deviation band
iterations = range(len(means))

plt.figure(figsize=(10, 6))
plt.plot(iterations, means, color='blue', label='CV Mean')
plt.fill_between(iterations,
                 np.array(means) - np.array(stds),
                 np.array(means) + np.array(stds),
                 alpha=0.3, color='blue', label='CV Std Dev')
plt.xlabel('Iteration')
plt.ylabel('Multi Log Loss')
plt.title('Cross-Validation Results with Standard Deviation')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

```

317

318

### Custom Callback Implementation

319

320

```python

import lightgbm as lgb
from sklearn.datasets import load_diabetes


def custom_metric_tracker(metric_threshold=0.1):
    """
    Custom callback that reports when the RMSE first drops below a threshold.

    The message is printed once per dataset, the first time that dataset's
    RMSE falls below ``metric_threshold``.
    """
    reported = set()  # datasets whose threshold crossing was already reported

    def callback(env):
        # env contains information about the current training state:
        # env.model: current model
        # env.params: training parameters
        # env.iteration: current iteration (0-based)
        # env.begin_iteration: beginning iteration
        # env.end_iteration: ending iteration
        # env.evaluation_result_list: (dataset_name, metric_name, value,
        #     is_higher_better) tuples; empty unless train() gets valid_sets
        # LightGBM ignores a callback's return value; to stop training,
        # raise lgb.callback.EarlyStopException instead.
        for dataset_name, metric_name, metric_value, _ in env.evaluation_result_list or []:
            if (
                metric_name == 'rmse'
                and metric_value < metric_threshold
                and dataset_name not in reported
            ):
                reported.add(dataset_name)
                print(f"🎯 Metric threshold reached! RMSE: {metric_value:.4f} at iteration {env.iteration}")

    return callback


def custom_progress_bar(total_rounds, bar_length=50):
    """
    Custom callback that draws a text progress bar on a single console line.
    """
    def callback(env):
        current = env.iteration - env.begin_iteration + 1
        progress = current / total_rounds
        filled_length = int(bar_length * progress)

        bar = '█' * filled_length + '-' * (bar_length - filled_length)
        percent = progress * 100

        # '\r' redraws the bar in place instead of printing a new line
        print(f'\rProgress: |{bar}| {percent:.1f}% ({current}/{total_rounds})', end='')

        if current == total_rounds:
            print()  # New line when complete

    return callback


# Load data (load_boston was removed from scikit-learn in version 1.2)
X, y = load_diabetes(return_X_y=True)
train_data = lgb.Dataset(X, label=y)

# Pick a threshold on the scale of the target: half its standard deviation
rmse_threshold = 0.5 * y.std()

eval_result = {}

# Train with custom callbacks
model = lgb.train(
    {
        'objective': 'regression',
        'metric': 'rmse',
        'num_leaves': 31,
        'learning_rate': 0.05,
        'verbose': -1,
    },
    train_data,
    num_boost_round=100,
    # Evaluate on the training data so env.evaluation_result_list is populated
    # (recorded under the name 'training')
    valid_sets=[train_data],
    callbacks=[
        custom_progress_bar(100),                # Custom progress tracking
        custom_metric_tracker(rmse_threshold),   # Alert when training RMSE < threshold
        lgb.log_evaluation(period=25),           # Standard logging
        lgb.record_evaluation(eval_result),      # History for the final report
    ],
)

print("\nTraining completed!")
# The returned Booster cannot run eval_train() (keep_training_booster=False),
# so read the final score from the recorded history.
print(f"Final RMSE: {eval_result['training']['rmse'][-1]:.4f}")

```

395

396

### Callback with sklearn Interface

397

398

```python

import lightgbm as lgb
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Load data
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Set up evaluation tracking
eval_result = {}

# Use callbacks with sklearn interface
model = lgb.LGBMClassifier(
    objective='multiclass',
    n_estimators=100,
    learning_rate=0.05,
    num_leaves=31,
    random_state=42,
)

# Fit with callbacks. LightGBM 4.0 removed fit()'s early_stopping_rounds and
# verbose arguments; the equivalent callbacks are passed instead.
model.fit(
    X_train, y_train,
    eval_set=[(X_train, y_train), (X_test, y_test)],
    eval_names=['train', 'test'],
    eval_metric='multi_logloss',
    callbacks=[
        lgb.early_stopping(stopping_rounds=15),  # the training set is ignored for stopping
        lgb.log_evaluation(period=1),            # replaces verbose=True
        lgb.record_evaluation(eval_result),
    ],
)

# Access recorded results. best_iteration_ is 1-based, the recorded history
# is 0-indexed.
print(f"Best iteration: {model.best_iteration_}")
print(f"Best test score: {eval_result['test']['multi_logloss'][model.best_iteration_ - 1]:.4f}")

# Make predictions
predictions = model.predict(X_test)
probabilities = model.predict_proba(X_test)  # one column per class

print(f"Test accuracy: {(predictions == y_test).mean():.4f}")

```

442

443

## Advanced Callback Patterns

444

445

### Conditional Early Stopping

446

447

```python

import lightgbm as lgb
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split


def conditional_early_stopping(stopping_rounds, condition_func):
    """
    Early stopping that only triggers when a condition is met.

    Tracks the first entry of env.evaluation_result_list (first dataset,
    first metric). Training stops once that metric has not improved for
    ``stopping_rounds`` rounds AND ``condition_func(env)`` returns True.
    The metric's is_higher_better flag decides the direction of improvement.
    """
    best_score = None
    best_iteration = 0
    best_results = None
    rounds_without_improvement = 0

    def callback(env):
        nonlocal best_score, best_iteration, best_results, rounds_without_improvement

        # Reset at the start of every training run so the callback can be reused
        if env.iteration == env.begin_iteration:
            best_score = None
            best_iteration = 0
            best_results = None
            rounds_without_improvement = 0

        if not env.evaluation_result_list:
            return

        # Fields 2 and 3 are the metric value and its is_higher_better flag
        current_score, is_higher_better = env.evaluation_result_list[0][2:4]
        if best_score is None:
            improved = True
        elif is_higher_better:
            improved = current_score > best_score
        else:
            improved = current_score < best_score

        if improved:
            best_score = current_score
            best_iteration = env.iteration
            best_results = list(env.evaluation_result_list)  # scores at the best iteration
            rounds_without_improvement = 0
        else:
            rounds_without_improvement += 1

        # Only stop if condition is met AND stopping rounds exceeded
        if condition_func(env) and rounds_without_improvement >= stopping_rounds:
            print(f"Conditional early stopping at iteration {env.iteration}")
            # best_iteration is 0-based like env.iteration; train() adds 1
            raise lgb.callback.EarlyStopException(best_iteration, best_results)

    return callback


# Example usage
def stop_condition(env):
    """Stop only if we've trained for at least 50 iterations."""
    return env.iteration >= 50


# Example data, prepared as in the Early Stopping Example
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
train_data = lgb.Dataset(X_train, label=y_train)
test_data = lgb.Dataset(X_test, label=y_test, reference=train_data)

params = {
    'objective': 'binary',
    'metric': 'binary_logloss',
    'num_leaves': 31,
    'learning_rate': 0.05,
    'verbose': -1,
}

# Use conditional early stopping
model = lgb.train(
    params,
    train_data,
    num_boost_round=200,
    valid_sets=[test_data],
    callbacks=[
        conditional_early_stopping(10, stop_condition),
        lgb.log_evaluation(20),
    ],
)

```

495

496

### Multi-Metric Monitoring

497

498

```python

import lightgbm as lgb


def multi_metric_monitor(metrics_config):
    """
    Monitor multiple metrics with different thresholds and behaviors.

    A metric reaches its threshold when it is better than the threshold; the
    metric's is_higher_better flag decides which direction counts as better.

    Args:
        metrics_config: dict like {
            'rmse': {'threshold': 5.0, 'action': 'alert'},
            'mae': {'threshold': 3.0, 'action': 'stop'}
        }
        'alert' prints a message on every iteration the threshold is reached;
        'stop' ends training at the current iteration.
    """
    def callback(env):
        for _, metric_name, metric_value, is_higher_better in env.evaluation_result_list or []:
            config = metrics_config.get(metric_name)
            if config is None:
                continue

            threshold = config['threshold']
            if is_higher_better:
                reached = metric_value > threshold
            else:
                reached = metric_value < threshold
            if not reached:
                continue

            action = config['action']
            if action == 'alert':
                print(f"🔔 {metric_name} threshold reached: {metric_value:.4f}")
            elif action == 'stop':
                print(f"🛑 Stopping due to {metric_name}: {metric_value:.4f}")
                # env.iteration is 0-based, as EarlyStopException expects;
                # the current results become the model's best_score.
                raise lgb.callback.EarlyStopException(env.iteration, env.evaluation_result_list)

    return callback


# Example usage.
# Assumes train_data / test_data are regression Datasets (e.g. as prepared in
# the Comprehensive Logging Example). The thresholds are illustrative: choose
# values on the scale of your own target.
metrics_config = {
    'rmse': {'threshold': 4.0, 'action': 'alert'},
    'mae': {'threshold': 3.0, 'action': 'stop'},
}

model = lgb.train(
    {
        'objective': 'regression',
        'metric': ['rmse', 'mae'],
        'verbose': -1,
    },
    train_data,
    num_boost_round=200,
    valid_sets=[test_data],
    callbacks=[
        multi_metric_monitor(metrics_config),
        lgb.log_evaluation(25),
    ],
)

```