or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

callbacks.md · composition.md · config-schema.md · config-store.md · errors.md · index.md · initialization.md · main-decorator.md · types.md · utilities.md

docs/callbacks.md

0

# Callbacks

1

2

Experimental callback API for hooking into Hydra's execution lifecycle. Callbacks enable custom logic at different stages of application execution including run start/end, multirun events, and individual job events.

3

4

## Capabilities

5

6

### Callback Base Class

7

8

Base class for implementing custom callbacks that respond to Hydra execution events.

9

10

```python { .api }

11

class Callback:
    """Interface for hooks that Hydra invokes during its execution lifecycle."""

    def on_run_start(self, config: DictConfig, **kwargs: Any) -> None:
        """
        Hook fired in RUN mode, before the job/application code starts.

        Parameters:
        - config: Composed configuration with overrides applied
        - **kwargs: Additional context (reserved for future extensibility)

        Note: Some hydra.runtime configs may not be populated yet.
        """

    def on_run_end(self, config: DictConfig, **kwargs: Any) -> None:
        """
        Hook fired in RUN mode, after the job/application code returns.

        Parameters:
        - config: The configuration used for the run
        - **kwargs: Additional context
        """

    def on_multirun_start(self, config: DictConfig, **kwargs: Any) -> None:
        """
        Hook fired in MULTIRUN mode, before any job starts.

        Parameters:
        - config: Base configuration before parameter sweeps
        - **kwargs: Additional context

        Note: When using a launcher, this executes on the local machine,
        before any Sweeper/Launcher is initialized.
        """

    def on_multirun_end(self, config: DictConfig, **kwargs: Any) -> None:
        """
        Hook fired in MULTIRUN mode, after all jobs return.

        Parameters:
        - config: Base configuration
        - **kwargs: Additional context

        Note: When using a launcher, this executes on the local machine.
        """

    def on_job_start(
        self,
        config: DictConfig,
        *,
        task_function: TaskFunction,
        **kwargs: Any
    ) -> None:
        """
        Hook fired for each Hydra job, in both RUN and MULTIRUN modes.

        Parameters:
        - config: Configuration for this specific job
        - task_function: The function decorated with @hydra.main
        - **kwargs: Additional context

        Note: In remote launching, this executes on the remote server
        along with your application code.
        """

    def on_job_end(
        self,
        config: DictConfig,
        job_return: JobReturn,
        **kwargs: Any
    ) -> None:
        """
        Hook fired after each job completes, in both RUN and MULTIRUN modes.

        Parameters:
        - config: Configuration for the completed job
        - job_return: Information about job execution and results
        - **kwargs: Additional context

        Note: In remote launching, this executes on the remote server
        after your application code.
        """

93

```

94

95

## Usage Examples

96

97

### Basic Callback Implementation

98

99

```python

100

import logging
from typing import Any

from hydra.core.utils import JobReturn, JobStatus
from hydra.experimental.callback import Callback
from hydra.types import TaskFunction
from omegaconf import DictConfig


class LoggingCallback(Callback):
    """Simple callback that logs execution events."""

    def __init__(self) -> None:
        self.logger = logging.getLogger(__name__)

    def on_run_start(self, config: DictConfig, **kwargs: Any) -> None:
        """Log the start of a single run, using the optional `name` key."""
        self.logger.info(f"Starting run with config: {config.get('name', 'unnamed')}")

    def on_run_end(self, config: DictConfig, **kwargs: Any) -> None:
        """Log that the single run has finished."""
        self.logger.info("Run completed")

    def on_multirun_start(self, config: DictConfig, **kwargs: Any) -> None:
        """Log the start of a multirun."""
        self.logger.info("Starting multirun")

    def on_multirun_end(self, config: DictConfig, **kwargs: Any) -> None:
        """Log that the multirun has finished."""
        self.logger.info("Multirun completed")

    def on_job_start(
        self,
        config: DictConfig,
        *,
        task_function: TaskFunction,
        **kwargs: Any
    ) -> None:
        """Log the name of the job that is about to start."""
        job_name = config.get('hydra', {}).get('job', {}).get('name', 'unknown')
        self.logger.info(f"Starting job: {job_name}")

    def on_job_end(
        self,
        config: DictConfig,
        job_return: JobReturn,
        **kwargs: Any
    ) -> None:
        """Log whether the job finished successfully."""
        job_name = config.get('hydra', {}).get('job', {}).get('name', 'unknown')
        # The status enum is hydra.core.utils.JobStatus; JobReturn has no
        # nested `Status` attribute.
        status = "SUCCESS" if job_return.status == JobStatus.COMPLETED else "FAILED"
        self.logger.info(f"Job {job_name} finished with status: {status}")

144

```

145

146

### Performance Monitoring Callback

147

148

```python

149

import time
from typing import Any, Dict

from hydra.core.utils import JobReturn
from hydra.experimental.callback import Callback
from hydra.types import TaskFunction
from omegaconf import DictConfig


class PerformanceCallback(Callback):
    """Callback for monitoring execution performance."""

    def __init__(self) -> None:
        # Start markers keyed by 'run', 'multirun' or 'job_<id>'.
        self.start_times: Dict[str, float] = {}
        self.metrics: Dict[str, Any] = {}

    def on_run_start(self, config: DictConfig, **kwargs: Any) -> None:
        """Record when the run started."""
        # perf_counter() is monotonic, so measured durations are not skewed
        # by system clock adjustments.
        self.start_times['run'] = time.perf_counter()
        print("Performance monitoring started")

    def on_run_end(self, config: DictConfig, **kwargs: Any) -> None:
        """Report the elapsed time of the run."""
        duration = time.perf_counter() - self.start_times['run']
        print(f"Total execution time: {duration:.2f} seconds")

    def on_multirun_start(self, config: DictConfig, **kwargs: Any) -> None:
        """Record when the multirun started and reset the job counter."""
        self.start_times['multirun'] = time.perf_counter()
        self.metrics['jobs_completed'] = 0
        print("Multirun performance monitoring started")

    def on_multirun_end(self, config: DictConfig, **kwargs: Any) -> None:
        """Report total multirun time, job count and time per job."""
        total_duration = time.perf_counter() - self.start_times['multirun']
        jobs = self.metrics.get('jobs_completed', 0)
        # Guard against division by zero when no job reported completion.
        avg_job_time = total_duration / jobs if jobs > 0 else 0

        print(f"Multirun completed in {total_duration:.2f} seconds")
        print(f"Jobs completed: {jobs}")
        print(f"Average job time: {avg_job_time:.2f} seconds")

    def on_job_start(
        self,
        config: DictConfig,
        *,
        task_function: TaskFunction,
        **kwargs: Any
    ) -> None:
        """Record when the job identified by hydra.job.id started."""
        job_id = config.get('hydra', {}).get('job', {}).get('id', 'unknown')
        self.start_times[f'job_{job_id}'] = time.perf_counter()

    def on_job_end(
        self,
        config: DictConfig,
        job_return: JobReturn,
        **kwargs: Any
    ) -> None:
        """Report the job's elapsed time and count it as completed."""
        job_id = config.get('hydra', {}).get('job', {}).get('id', 'unknown')
        start_key = f'job_{job_id}'

        # pop() both reads and discards the marker; a job whose start was
        # never recorded is still counted below but gets no timing line.
        started = self.start_times.pop(start_key, None)
        if started is not None:
            duration = time.perf_counter() - started
            print(f"Job {job_id} completed in {duration:.2f} seconds")

        self.metrics['jobs_completed'] = self.metrics.get('jobs_completed', 0) + 1

209

```

210

211

### Configuration Validation Callback

212

213

```python

214

from typing import Any, List, Optional

from hydra.core.utils import JobReturn, JobStatus
from hydra.experimental.callback import Callback
from hydra.types import TaskFunction
from omegaconf import DictConfig


class ValidationCallback(Callback):
    """Callback for validating configurations."""

    def __init__(self, required_keys: Optional[List[str]] = None) -> None:
        # None (or an empty list) means no keys are required.
        self.required_keys = required_keys or []

    def on_job_start(
        self,
        config: DictConfig,
        *,
        task_function: TaskFunction,
        **kwargs: Any
    ) -> None:
        """
        Validate configuration before job execution.

        Raises:
        - ValueError: if a required key is missing, or if a `database`
          section is present with a non-positive `port`.
        """

        # Check required keys
        for key in self.required_keys:
            if key not in config:
                raise ValueError(f"Required configuration key missing: {key}")

        # Custom validation logic
        if hasattr(config, 'database') and config.database:
            if config.database.get('port', 0) <= 0:
                raise ValueError("Database port must be positive")

        print("Configuration validation passed")

    def on_job_end(
        self,
        config: DictConfig,
        job_return: JobReturn,
        **kwargs: Any
    ) -> None:
        """Print the configuration of any job that failed."""
        # The status enum is hydra.core.utils.JobStatus; JobReturn has no
        # nested `Status` attribute.
        if job_return.status == JobStatus.FAILED:
            print(f"Job failed with configuration: {config}")

254

```

255

256

### Results Aggregation Callback

257

258

```python

259

import json
from pathlib import Path
from typing import Any, Dict, List

from hydra.core.utils import JobReturn, JobStatus
from hydra.experimental.callback import Callback
from omegaconf import DictConfig, OmegaConf


class ResultsAggregatorCallback(Callback):
    """Callback for aggregating results from multirun experiments."""

    def __init__(self, output_file: str = "results.json") -> None:
        self.output_file = output_file
        self.results: List[Dict[str, Any]] = []

    def on_multirun_start(self, config: DictConfig, **kwargs: Any) -> None:
        """Start a fresh result list for the new multirun."""
        self.results = []  # Reset results for new multirun
        print("Results aggregation started")

    def on_job_end(
        self,
        config: DictConfig,
        job_return: JobReturn,
        **kwargs: Any
    ) -> None:
        """Collect results from each job."""

        # JobReturn.return_value re-raises the job's exception for a failed
        # job, so only read it once the job is known to have completed.
        if job_return.status == JobStatus.COMPLETED:
            return_value = job_return.return_value
        else:
            return_value = None

        hydra_node = config.get('hydra')

        result = {
            'job_id': config.get('hydra', {}).get('job', {}).get('id'),
            # to_container() converts recursively into plain dicts/lists;
            # dict(config) would leave nested config nodes in place.
            'config': OmegaConf.to_container(config),
            'status': str(job_return.status),
            'return_value': return_value,
            'hydra_cfg': (
                OmegaConf.to_container(hydra_node) if hydra_node is not None else {}
            ),
        }

        self.results.append(result)
        print(f"Collected result from job {result['job_id']}")

    def on_multirun_end(self, config: DictConfig, **kwargs: Any) -> None:
        """Save aggregated results to file."""

        output_path = Path(self.output_file)
        with open(output_path, 'w', encoding='utf-8') as f:
            # default=str is a fallback for return values that are not
            # natively JSON-serializable.
            json.dump(self.results, f, indent=2, default=str)

        print(f"Results saved to {output_path}")
        print(f"Total jobs processed: {len(self.results)}")

305

```

306

307

### Callback Registration and Configuration

308

309

```python

310

# Callbacks are configured through Hydra's configuration system: Hydra
# instantiates every entry found under `hydra.callbacks` in the composed
# configuration. The ConfigStore entries below provide such entries.

from dataclasses import dataclass, field
from typing import List

from hydra import main, initialize, compose
from hydra.core.config_store import ConfigStore


@dataclass
class CallbackConfig:
    _target_: str
    # Additional callback-specific parameters


@dataclass
class AppConfig:
    name: str = "MyApp"
    # Application-level data only; Hydra looks for callbacks under
    # `hydra.callbacks`, not in this field.
    callbacks: List[CallbackConfig] = field(default_factory=list)


# Register callback configs.
# Each option is stored in the `hydra/callbacks` group with the `_global_`
# package, so its content is merged at `hydra.callbacks.<callback name>`.
cs = ConfigStore.instance()
cs.store(
    group="hydra/callbacks",
    name="logging_callback",
    package="_global_",
    node={
        "hydra": {
            "callbacks": {
                "logging_callback": CallbackConfig(
                    _target_="__main__.LoggingCallback"
                ),
            },
        },
    },
)

cs.store(
    group="hydra/callbacks",
    name="performance_callback",
    package="_global_",
    node={
        "hydra": {
            "callbacks": {
                "performance_callback": CallbackConfig(
                    _target_="__main__.PerformanceCallback"
                ),
            },
        },
    },
)

# Use in configuration files:
# config.yaml:
# defaults:
#   - /hydra/callbacks:
#       - logging_callback
#       - performance_callback

342

```

343

344

### Integration with Hydra Application

345

346

```python

347

from hydra import main

348

from omegaconf import DictConfig

349

350

# Callbacks registered through configuration are invoked automatically
@main(version_base=None, config_path="conf", config_name="config")
def my_app(cfg: DictConfig) -> str:
    """Example task function that runs with callback integration."""
    import time

    print(f"Running application: {cfg.name}")

    # Stand-in for real work
    time.sleep(1)

    summary = f"Processed {cfg.get('items', 0)} items"
    print(summary)

    # The returned value is available in the on_job_end callback
    return summary


if __name__ == "__main__":
    my_app()

368

```

369

370

### Advanced Callback Patterns

371

372

```python

373

import threading
import time
from typing import Any, Dict

from hydra.core.utils import JobReturn, JobStatus
from hydra.experimental.callback import Callback
from hydra.types import TaskFunction
from omegaconf import DictConfig


class ThreadSafeCallback(Callback):
    """Thread-safe callback for concurrent job execution."""

    def __init__(self) -> None:
        # Every read and write of _shared_state happens while holding _lock.
        self._lock = threading.Lock()
        self._shared_state: Dict[str, Any] = {}

    def on_job_start(
        self,
        config: DictConfig,
        *,
        task_function: TaskFunction,
        **kwargs: Any
    ) -> None:
        """Mark the job as running and record its start time."""
        with self._lock:
            job_id = config.get('hydra', {}).get('job', {}).get('id', 'unknown')
            self._shared_state[job_id] = {'status': 'running', 'start_time': time.time()}

    def on_job_end(
        self,
        config: DictConfig,
        job_return: JobReturn,
        **kwargs: Any
    ) -> None:
        """Mark a previously started job as completed and record its outcome."""
        with self._lock:
            job_id = config.get('hydra', {}).get('job', {}).get('id', 'unknown')
            # Jobs whose start was never recorded are ignored.
            if job_id in self._shared_state:
                self._shared_state[job_id].update({
                    'status': 'completed',
                    'end_time': time.time(),
                    # The status enum is hydra.core.utils.JobStatus;
                    # JobReturn has no nested `Status` attribute.
                    'success': job_return.status == JobStatus.COMPLETED,
                })

410

411

class ConditionalCallback(Callback):
    """Callback that only executes under certain conditions."""

    def __init__(self, condition_key: str, condition_value: Any):
        # Key looked up in the job config, and the value it must equal for
        # the callback to fire.
        self.condition_key = condition_key
        self.condition_value = condition_value

    def _should_execute(self, config: DictConfig) -> bool:
        """
        Check if callback should execute based on configuration.

        Returns True only when the value selected by `condition_key` equals
        `condition_value`; any error during the lookup counts as "no".
        """
        from omegaconf import OmegaConf

        try:
            actual_value = OmegaConf.select(config, self.condition_key)
        except Exception:
            # Best-effort lookup: a failed selection disables the callback.
            # Unlike a bare `except:`, this lets KeyboardInterrupt and
            # SystemExit propagate.
            return False
        return actual_value == self.condition_value

    def on_job_start(
        self,
        config: DictConfig,
        *,
        task_function: TaskFunction,
        **kwargs: Any
    ) -> None:
        """Print a notice when the configured condition holds for this job."""
        if self._should_execute(config):
            print(f"Conditional callback triggered for {self.condition_key}={self.condition_value}")

437

```