or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

base-exceptions.md · contrib.md · distributed.md · engine.md · handlers.md · index.md · metrics.md · utils.md

handlers.mddocs/

0

# Handlers and Training Enhancement

1

2

Training enhancement utilities including checkpointing, early stopping, logging, learning rate scheduling, and experiment tracking. PyTorch Ignite provides 40+ built-in handlers that plug into the event system to enhance training workflows.

3

4

## Capabilities

5

6

### Checkpointing

7

8

Model and training state checkpointing with flexible save strategies.

9

10

```python { .api }

11

class Checkpoint:

12

"""

13

Flexible checkpointing handler.

14

15

Parameters:

16

- to_save: dictionary of objects to save

17

- save_handler: handler for saving (DiskSaver, etc.)

18

- filename_prefix: prefix for checkpoint filenames

19

- score_function: function to compute checkpoint score

20

- score_name: name of the score metric

21

- n_saved: number of checkpoints to keep

22

- atomic: whether to use atomic saves

23

- require_empty: require empty directory

24

- archived: whether to archive old checkpoints

25

- greater_or_equal: score comparison direction

26

"""

27

def __init__(self, to_save, save_handler, filename_prefix="", score_function=None, score_name=None, n_saved=1, atomic=True, require_empty=True, archived=False, greater_or_equal=False): ...

28

29

class DiskSaver:

30

"""

31

Disk-based checkpoint saver.

32

33

Parameters:

34

- dirname: directory to save checkpoints

35

- atomic: whether to use atomic saves

36

- create_dir: whether to create directory if it doesn't exist

37

- require_empty: require empty directory

38

"""

39

def __init__(self, dirname, atomic=True, create_dir=True, require_empty=True): ...

40

41

class ModelCheckpoint:

42

"""

43

Model checkpoint handler (deprecated - use Checkpoint instead).

44

45

Parameters:

46

- dirname: directory to save checkpoints

47

- filename_prefix: prefix for checkpoint filenames

48

- score_function: function to compute checkpoint score

49

- score_name: name of the score metric

50

- n_saved: number of checkpoints to keep

51

- atomic: whether to use atomic saves

52

- require_empty: require empty directory

53

- create_dir: whether to create directory

54

- save_as_state_dict: save as state dict instead of full model

55

- global_step_transform: function to transform global step

56

"""

57

def __init__(self, dirname, filename_prefix, score_function=None, score_name=None, n_saved=1, atomic=True, require_empty=True, create_dir=True, save_as_state_dict=True, global_step_transform=None): ...

58

```

59

60

### Early Stopping

61

62

Early stopping based on validation metrics to prevent overfitting.

63

64

```python { .api }

65

class EarlyStopping:

66

"""

67

Early stopping handler to prevent overfitting.

68

69

Parameters:

70

- patience: number of events to wait before stopping

71

- score_function: function to compute stopping score

72

- trainer: trainer engine to stop

73

- min_delta: minimum increase in the score to qualify as an improvement (smaller changes do not reset patience)

74

- cumulative_delta: whether to use cumulative delta

75

"""

76

def __init__(self, patience, score_function, trainer, min_delta=0.0, cumulative_delta=False): ...

77

```

78

79

### Learning Rate Scheduling

80

81

Learning rate scheduling with various strategies and warmup support.

82

83

```python { .api }

84

class LRScheduler:

85

"""

86

Learning rate scheduler wrapper.

87

88

Parameters:

89

- lr_scheduler: PyTorch learning rate scheduler

90

- save_history: whether to save LR history

91

- **kwds: additional arguments

92

"""

93

def __init__(self, lr_scheduler, save_history=False, **kwds): ...

94

95

def create_lr_scheduler_with_warmup(lr_scheduler, warmup_start_value, warmup_end_value, warmup_duration, save_history=False):

96

"""

97

Create learning rate scheduler with warmup.

98

99

Parameters:

100

- lr_scheduler: base learning rate scheduler

101

- warmup_start_value: starting learning rate for warmup

102

- warmup_end_value: ending learning rate for warmup

103

- warmup_duration: duration of warmup phase

104

- save_history: whether to save LR history

105

106

Returns:

107

Combined scheduler with warmup

108

"""

109

110

class CosineAnnealingScheduler:

111

"""

112

Cosine annealing scheduler.

113

114

Parameters:

115

- optimizer: PyTorch optimizer

116

- param_name: parameter name to schedule

117

- start_value: starting parameter value

118

- end_value: ending parameter value

119

- cycle_size: size of one cycle

120

- cycle_mult: cycle size multiplier

121

- start_value_mult: start value multiplier per cycle

122

- end_value_mult: end value multiplier per cycle

123

- save_history: whether to save parameter history

124

"""

125

def __init__(self, optimizer, param_name, start_value, end_value, cycle_size, cycle_mult=1.0, start_value_mult=1.0, end_value_mult=1.0, save_history=False): ...

126

127

class LinearCyclicalScheduler:

128

"""

129

Linear cyclical scheduler.

130

131

Parameters:

132

- optimizer: PyTorch optimizer

133

- param_name: parameter name to schedule

134

- start_value: starting parameter value

135

- end_value: ending parameter value

136

- cycle_size: size of one cycle

137

- cycle_mult: cycle size multiplier

138

- start_value_mult: start value multiplier per cycle

139

- end_value_mult: end value multiplier per cycle

140

- save_history: whether to save parameter history

141

"""

142

def __init__(self, optimizer, param_name, start_value, end_value, cycle_size, cycle_mult=1.0, start_value_mult=1.0, end_value_mult=1.0, save_history=False): ...

143

144

class ConcatScheduler:

145

"""

146

Concatenated scheduler combining multiple schedulers.

147

148

Parameters:

149

- schedulers: list of parameter schedulers to run in sequence

150

- durations: list of durations for each scheduler

151

- save_history: whether to save parameter history

152

"""

153

def __init__(self, schedulers, durations, save_history=False): ...

154

155

class PiecewiseLinear:

156

"""

157

Piecewise linear scheduler.

158

159

Parameters:

160

- optimizer: PyTorch optimizer

161

- param_name: parameter name to schedule

162

- milestones_values: list of (milestone, value) tuples

163

- save_history: whether to save parameter history

164

"""

165

def __init__(self, optimizer, param_name, milestones_values, save_history=False): ...

166

```

167

168

### Parameter Scheduling

169

170

General parameter scheduling framework for optimizers.

171

172

```python { .api }

173

class ParamScheduler:

174

"""

175

Base parameter scheduler class.

176

177

Parameters:

178

- optimizer: PyTorch optimizer

179

- param_name: parameter name to schedule

180

- save_history: whether to save parameter history

181

"""

182

def __init__(self, optimizer, param_name, save_history=False): ...

183

184

class ParamGroupScheduler:

185

"""

186

Parameter group scheduler for different parameter groups.

187

188

Parameters:

189

- schedulers: list of schedulers for each parameter group

190

- names: names for each parameter group

191

"""

192

def __init__(self, schedulers, names=None): ...

193

194

class StateParamScheduler:

195

"""

196

State-based parameter scheduler.

197

198

Parameters:

199

- param_scheduler: base parameter scheduler

200

- param_name: parameter name

201

- save_history: whether to save parameter history

202

"""

203

def __init__(self, param_scheduler, param_name, save_history=False): ...

204

205

class LambdaStateScheduler(StateParamScheduler):

206

"""

207

Lambda-based state parameter scheduler.

208

209

Parameters:

210

- lambda_func: lambda function for scheduling

211

- param_name: parameter name

212

- save_history: whether to save parameter history

213

"""

214

def __init__(self, lambda_func, param_name, save_history=False): ...

215

216

class ExpStateScheduler(StateParamScheduler):

217

"""

218

Exponential decay state parameter scheduler.

219

220

Parameters:

221

- gamma: exponential decay factor

222

- param_name: parameter name

223

- save_history: whether to save parameter history

224

"""

225

def __init__(self, gamma, param_name, save_history=False): ...

226

227

class StepStateScheduler(StateParamScheduler):

228

"""

229

Step-based state parameter scheduler.

230

231

Parameters:

232

- step_size: step size for scheduling

233

- gamma: decay factor

234

- param_name: parameter name

235

- save_history: whether to save parameter history

236

"""

237

def __init__(self, step_size, gamma, param_name, save_history=False): ...

238

239

class MultiStepStateScheduler(StateParamScheduler):

240

"""

241

Multi-step state parameter scheduler.

242

243

Parameters:

244

- milestones: list of milestones

245

- gamma: decay factor

246

- param_name: parameter name

247

- save_history: whether to save parameter history

248

"""

249

def __init__(self, milestones, gamma, param_name, save_history=False): ...

250

251

class PiecewiseLinearStateScheduler(StateParamScheduler):

252

"""

253

Piecewise linear state parameter scheduler.

254

255

Parameters:

256

- milestones_values: list of (milestone, value) tuples

257

- param_name: parameter name

258

- save_history: whether to save parameter history

259

"""

260

def __init__(self, milestones_values, param_name, save_history=False): ...

261

```

262

263

### Logging and Tracking

264

265

Integration with popular experiment tracking and logging frameworks.

266

267

```python { .api }

268

class TensorboardLogger:

269

"""

270

TensorBoard logging handler.

271

272

Parameters:

273

- log_dir: directory for TensorBoard logs

274

- **kwargs: additional arguments for SummaryWriter

275

"""

276

def __init__(self, log_dir=None, **kwargs): ...

277

278

def attach_output_handler(self, engine, event_name, tag, output_transform=None, metric_names=None, global_step_transform=None):

279

"""Attach output logging handler."""

280

281

def attach_opt_params_handler(self, engine, event_name, optimizer, param_name="lr"):

282

"""Attach optimizer parameter logging handler."""

283

284

class VisdomLogger:

285

"""

286

Visdom logging handler.

287

288

Parameters:

289

- server: Visdom server URL

290

- port: server port

291

- **kwargs: additional Visdom arguments

292

"""

293

def __init__(self, server=None, port=8097, **kwargs): ...

294

295

class MLflowLogger:

296

"""

297

MLflow experiment tracking.

298

299

Parameters:

300

- tracking_uri: MLflow tracking server URI

301

- experiment_name: name of the experiment

302

- run_name: name of the run

303

- artifact_location: artifact storage location

304

- **kwargs: additional MLflow arguments

305

"""

306

def __init__(self, tracking_uri=None, experiment_name=None, run_name=None, artifact_location=None, **kwargs): ...

307

308

class NeptuneLogger:

309

"""

310

Neptune experiment tracking.

311

312

Parameters:

313

- api_token: Neptune API token

314

- project_name: Neptune project name

315

- experiment_name: name of the experiment

316

- **kwargs: additional Neptune arguments

317

"""

318

def __init__(self, api_token=None, project_name=None, experiment_name=None, **kwargs): ...

319

320

class WandBLogger:

321

"""

322

Weights & Biases experiment tracking.

323

324

Parameters:

325

- project: W&B project name

326

- entity: W&B entity name

327

- config: configuration dictionary

328

- **kwargs: additional W&B arguments

329

"""

330

def __init__(self, project=None, entity=None, config=None, **kwargs): ...

331

332

class ClearMLLogger:

333

"""

334

ClearML experiment tracking.

335

336

Parameters:

337

- project_name: ClearML project name

338

- task_name: task name

339

- **kwargs: additional ClearML arguments

340

"""

341

def __init__(self, project_name=None, task_name=None, **kwargs): ...

342

343

class PolyaxonLogger:

344

"""

345

Polyaxon experiment tracking.

346

347

Parameters:

348

- **kwargs: Polyaxon configuration arguments

349

"""

350

def __init__(self, **kwargs): ...

351

```

352

353

### Progress and Timing

354

355

Progress bars and timing utilities for monitoring training.

356

357

```python { .api }

358

class ProgressBar:

359

"""

360

Progress bar for training monitoring.

361

362

Parameters:

363

- persist: whether to persist after completion

364

- bar_format: custom bar format string

365

- **tqdm_kwargs: additional tqdm arguments

366

"""

367

def __init__(self, persist=False, bar_format=None, **tqdm_kwargs): ...

368

369

class Timer:

370

"""

371

Timer for measuring elapsed time.

372

373

Parameters:

374

- average: whether to compute running average

375

"""

376

def __init__(self, average=False): ...

377

378

def value(self):

379

"""Get current timer value."""

380

381

def reset(self):

382

"""Reset timer."""

383

384

def pause(self):

385

"""Pause timer."""

386

387

def resume(self):

388

"""Resume timer."""

389

390

class BasicTimeProfiler:

391

"""

392

Basic profiler for timing engine operations.

393

394

Parameters:

395

- dataflow_profiling: whether to profile data loading

396

"""

397

def __init__(self, dataflow_profiling=False): ...

398

399

def print_results(self, results_dict):

400

"""Print profiling results."""

401

402

class HandlersTimeProfiler:

403

"""

404

Profiler for timing handler execution.

405

"""

406

def __init__(self): ...

407

```

408

409

### Model Enhancement

410

411

Handlers for enhancing model training behavior.

412

413

```python { .api }

414

class GradientAccumulation:

415

"""

416

Gradient accumulation handler.

417

418

Parameters:

419

- accumulation_steps: number of steps to accumulate gradients

420

"""

421

def __init__(self, accumulation_steps): ...

422

423

class EMAHandler:

424

"""

425

Exponential Moving Average handler for model parameters.

426

427

Parameters:

428

- model: PyTorch model

429

- decay: decay factor for EMA

430

- device: device to store EMA parameters

431

"""

432

def __init__(self, model, decay=0.9999, device=None): ...

433

434

class FastaiLRFinder:

435

"""

436

Learning rate finder inspired by fastai.

437

438

Parameters:

439

- engine: training engine

440

- optimizer: PyTorch optimizer

441

- criterion: loss function

442

- device: device to run on

443

"""

444

def __init__(self, engine, optimizer, criterion, device=None): ...

445

446

def range_test(self, data_loader, start_lr=1e-7, end_lr=10, num_iter=100, step_mode="exp"):

447

"""Perform learning rate range test."""

448

449

class TerminateOnNan:

450

"""

451

Terminate training when NaN values are encountered.

452

"""

453

def __init__(self): ...

454

455

class TimeLimit:

456

"""

457

Terminate training after specified time limit.

458

459

Parameters:

460

- limit: time limit in seconds

461

"""

462

def __init__(self, limit): ...

463

```

464

465

### Base Classes

466

467

Base classes for creating custom handlers and loggers.

468

469

```python { .api }

470

class BaseLogger:

471

"""Base class for loggers."""

472

def __init__(self): ...

473

474

class BaseOptimizerParams:

475

"""Base class for optimizer parameter handlers."""

476

def __init__(self): ...

477

478

class BaseOutputTransform:

479

"""Base class for output transformations."""

480

def __init__(self): ...

481

```

482

483

### Utility Functions

484

485

Helper functions for handlers and training enhancement.

486

487

```python { .api }

488

def global_step_from_engine(engine):

489

"""

490

Get global step from engine state.

491

492

Parameters:

493

- engine: engine instance

494

495

Returns:

496

Global step number

497

"""

498

```

499

500

## Usage Examples

501

502

### Model Checkpointing

503

504

```python

505

from ignite.handlers import Checkpoint, DiskSaver

506

507

# Create checkpoint handler

508

to_save = {'model': model, 'optimizer': optimizer}

509

save_handler = DiskSaver('checkpoints', create_dir=True)

510

511

checkpoint = Checkpoint(

512

to_save,

513

save_handler,

514

filename_prefix='best',

515

score_function=lambda engine: -engine.state.metrics['loss'],

516

score_name='neg_loss',

517

n_saved=3

518

)

519

520

# Attach to evaluator

521

evaluator.add_event_handler(Events.COMPLETED, checkpoint)

522

```

523

524

### Early Stopping

525

526

```python

527

from ignite.handlers import EarlyStopping

528

529

# Create early stopping handler

530

early_stopping = EarlyStopping(

531

patience=10,

532

score_function=lambda engine: engine.state.metrics['accuracy'],

533

trainer=trainer

534

)

535

536

# Attach to evaluator

537

evaluator.add_event_handler(Events.COMPLETED, early_stopping)

538

```

539

540

### Learning Rate Scheduling

541

542

```python

543

from ignite.handlers import LRScheduler

544

from torch.optim.lr_scheduler import StepLR

545

546

# Create PyTorch scheduler

547

torch_scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

548

549

# Wrap with Ignite scheduler

550

lr_scheduler = LRScheduler(torch_scheduler)

551

552

# Attach to trainer

553

trainer.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler)

554

555

# Access LR history if save_history=True

556

lr_scheduler = LRScheduler(torch_scheduler, save_history=True)

557

# ... after training, the history is stored on the engine state

558

print(trainer.state.param_history["lr"])

559

```

560

561

### TensorBoard Logging

562

563

```python

564

from ignite.handlers import TensorboardLogger

565

566

# Create TensorBoard logger

567

tb_logger = TensorboardLogger(log_dir='tb_logs')

568

569

# Log training loss

570

tb_logger.attach_output_handler(

571

trainer,

572

event_name=Events.ITERATION_COMPLETED(every=100),

573

tag="training",

574

output_transform=lambda loss: {"loss": loss}

575

)

576

577

# Log validation metrics

578

tb_logger.attach_output_handler(

579

evaluator,

580

event_name=Events.COMPLETED,

581

tag="validation",

582

metric_names=["accuracy", "loss"],

583

global_step_transform=global_step_from_engine(trainer)

584

)

585

586

# Log learning rate

587

tb_logger.attach_opt_params_handler(

588

trainer,

589

event_name=Events.ITERATION_COMPLETED(every=100),

590

optimizer=optimizer,

591

param_name="lr"

592

)

593

594

# Don't forget to close

595

trainer.add_event_handler(Events.COMPLETED, lambda _: tb_logger.close())

596

```

597

598

### Progress Bar

599

600

```python

601

from ignite.handlers import ProgressBar

602

603

# Create progress bar

604

pbar = ProgressBar(persist=True)

605

606

# Attach to trainer

607

pbar.attach(trainer, metric_names=['loss'])

608

609

# Or with custom output transform

610

pbar.attach(trainer, output_transform=lambda x: {'loss': x})

611

```