# Callbacks

Experimental callback API for hooking into Hydra's execution lifecycle. Callbacks let you run custom logic at well-defined stages of application execution: run start/end, multirun start/end, and the start/end of each individual job.

## Capabilities

### Callback Base Class

Base class for implementing custom callbacks that respond to Hydra execution events.

```python { .api }
class Callback:
    """Base class for Hydra callbacks.

    Each hook below has an empty body here; subclasses override only the
    events they are interested in.
    """

    def on_run_start(self, config: DictConfig, **kwargs: Any) -> None:
        """
        Called in RUN mode before job/application code starts.

        Parameters:
        - config: Composed configuration with overrides applied
        - **kwargs: Additional context (future extensibility)

        Note: Some hydra.runtime configs may not be populated yet.
        """

    def on_run_end(self, config: DictConfig, **kwargs: Any) -> None:
        """
        Called in RUN mode after job/application code returns.

        Parameters:
        - config: The configuration used for the run
        - **kwargs: Additional context
        """

    def on_multirun_start(self, config: DictConfig, **kwargs: Any) -> None:
        """
        Called in MULTIRUN mode before any job starts.

        Parameters:
        - config: Base configuration before parameter sweeps
        - **kwargs: Additional context

        Note: When using a launcher, this executes on local machine
        before any Sweeper/Launcher is initialized.
        """

    def on_multirun_end(self, config: DictConfig, **kwargs: Any) -> None:
        """
        Called in MULTIRUN mode after all jobs return.

        Parameters:
        - config: Base configuration
        - **kwargs: Additional context

        Note: When using a launcher, this executes on local machine.
        """

    def on_job_start(
        self,
        config: DictConfig,
        *,
        task_function: TaskFunction,
        **kwargs: Any
    ) -> None:
        """
        Called in both RUN and MULTIRUN modes for each Hydra job.

        Parameters:
        - config: Configuration for this specific job
        - task_function: The function decorated with @hydra.main
          (keyword-only)
        - **kwargs: Additional context

        Note: In remote launching, this executes on the remote server
        along with your application code.
        """

    def on_job_end(
        self,
        config: DictConfig,
        job_return: JobReturn,
        **kwargs: Any
    ) -> None:
        """
        Called in both RUN and MULTIRUN modes after each job completes.

        Parameters:
        - config: Configuration for the completed job
        - job_return: Information about job execution and results
        - **kwargs: Additional context

        Note: In remote launching, this executes on the remote server
        after your application code.
        """
```

## Usage Examples

### Basic Callback Implementation

```python
import logging
from typing import Any

from hydra.core.utils import JobReturn, JobStatus
from hydra.experimental.callback import Callback
from hydra.types import TaskFunction
from omegaconf import DictConfig


class LoggingCallback(Callback):
    """Simple callback that logs execution events."""

    def __init__(self) -> None:
        self.logger = logging.getLogger(__name__)

    def on_run_start(self, config: DictConfig, **kwargs: Any) -> None:
        self.logger.info(f"Starting run with config: {config.get('name', 'unnamed')}")

    def on_run_end(self, config: DictConfig, **kwargs: Any) -> None:
        self.logger.info("Run completed")

    def on_multirun_start(self, config: DictConfig, **kwargs: Any) -> None:
        self.logger.info("Starting multirun")

    def on_multirun_end(self, config: DictConfig, **kwargs: Any) -> None:
        self.logger.info("Multirun completed")

    def on_job_start(
        self,
        config: DictConfig,
        *,
        task_function: TaskFunction,
        **kwargs: Any
    ) -> None:
        # Chained .get() calls fall back to 'unknown' when any level is absent.
        job_name = config.get('hydra', {}).get('job', {}).get('name', 'unknown')
        self.logger.info(f"Starting job: {job_name}")

    def on_job_end(
        self,
        config: DictConfig,
        job_return: JobReturn,
        **kwargs: Any
    ) -> None:
        job_name = config.get('hydra', {}).get('job', {}).get('name', 'unknown')
        # The job status enum is hydra.core.utils.JobStatus; JobReturn has no
        # nested `Status` attribute.
        status = "SUCCESS" if job_return.status == JobStatus.COMPLETED else "FAILED"
        self.logger.info(f"Job {job_name} finished with status: {status}")
```

### Performance Monitoring Callback

```python
import time
from typing import Any, Dict

from hydra.core.utils import JobReturn
from hydra.experimental.callback import Callback
from hydra.types import TaskFunction
from omegaconf import DictConfig


class PerformanceCallback(Callback):
    """Callback for monitoring execution performance.

    Timing state lives on the instance, so the multirun summary only
    reflects jobs whose hooks ran on this same callback object.
    """

    def __init__(self) -> None:
        # Start timestamps keyed by scope: 'run', 'multirun' or 'job_<id>'.
        self.start_times: Dict[str, float] = {}
        self.metrics: Dict[str, Any] = {}

    @staticmethod
    def _job_id(config: DictConfig) -> Any:
        """Return hydra.job.id from the config, or 'unknown' if unavailable."""
        return config.get('hydra', {}).get('job', {}).get('id', 'unknown')

    def on_run_start(self, config: DictConfig, **kwargs: Any) -> None:
        # perf_counter() is monotonic, so durations are immune to clock changes.
        self.start_times['run'] = time.perf_counter()
        print("Performance monitoring started")

    def on_run_end(self, config: DictConfig, **kwargs: Any) -> None:
        duration = time.perf_counter() - self.start_times['run']
        print(f"Total execution time: {duration:.2f} seconds")

    def on_multirun_start(self, config: DictConfig, **kwargs: Any) -> None:
        self.start_times['multirun'] = time.perf_counter()
        self.metrics['jobs_completed'] = 0
        self.metrics['total_job_time'] = 0.0
        print("Multirun performance monitoring started")

    def on_multirun_end(self, config: DictConfig, **kwargs: Any) -> None:
        total_duration = time.perf_counter() - self.start_times['multirun']
        jobs = self.metrics.get('jobs_completed', 0)
        # Average the measured per-job durations rather than dividing the
        # wall-clock sweep time by the job count, which also counts launcher
        # overhead and is wrong when jobs overlap.
        total_job_time = self.metrics.get('total_job_time', 0.0)
        avg_job_time = total_job_time / jobs if jobs > 0 else 0

        print(f"Multirun completed in {total_duration:.2f} seconds")
        print(f"Jobs completed: {jobs}")
        print(f"Average job time: {avg_job_time:.2f} seconds")

    def on_job_start(
        self,
        config: DictConfig,
        *,
        task_function: TaskFunction,
        **kwargs: Any
    ) -> None:
        self.start_times[f'job_{self._job_id(config)}'] = time.perf_counter()

    def on_job_end(
        self,
        config: DictConfig,
        job_return: JobReturn,
        **kwargs: Any
    ) -> None:
        job_id = self._job_id(config)

        # pop() both reads and clears the entry; None means on_job_start was
        # never seen for this job on this instance.
        started = self.start_times.pop(f'job_{job_id}', None)
        if started is not None:
            duration = time.perf_counter() - started
            print(f"Job {job_id} completed in {duration:.2f} seconds")
            self.metrics['total_job_time'] = (
                self.metrics.get('total_job_time', 0.0) + duration
            )

        self.metrics['jobs_completed'] = self.metrics.get('jobs_completed', 0) + 1
```

### Configuration Validation Callback

```python
from typing import Any, List, Optional

from hydra.core.utils import JobReturn, JobStatus
from hydra.experimental.callback import Callback
from hydra.types import TaskFunction
from omegaconf import DictConfig


class ValidationCallback(Callback):
    """Callback for validating configurations."""

    def __init__(self, required_keys: Optional[List[str]] = None) -> None:
        # Copy so later mutation of the caller's list cannot change validation.
        self.required_keys = list(required_keys) if required_keys else []

    def on_job_start(
        self,
        config: DictConfig,
        *,
        task_function: TaskFunction,
        **kwargs: Any
    ) -> None:
        """Validate configuration before job execution.

        Raises:
            ValueError: if a required top-level key is missing, or if a
                `database` section is present without a positive `port`.

        NOTE(review): some Hydra versions appear to catch exceptions raised
        inside callbacks and report them as warnings instead of aborting the
        job - confirm the behaviour for the version you target.
        """

        # Check required keys (top-level keys only)
        for key in self.required_keys:
            if key not in config:
                raise ValueError(f"Required configuration key missing: {key}")

        # Custom validation logic
        database = config.get('database')
        if database:
            port = database.get('port', 0)
            # A null port would make `port <= 0` raise TypeError; treat it as
            # invalid instead.
            if port is None or port <= 0:
                raise ValueError("Database port must be positive")

        print("Configuration validation passed")

    def on_job_end(
        self,
        config: DictConfig,
        job_return: JobReturn,
        **kwargs: Any
    ) -> None:
        """Log job completion status."""
        # The job status enum is hydra.core.utils.JobStatus; JobReturn has no
        # nested `Status` attribute.
        if job_return.status == JobStatus.FAILED:
            print(f"Job failed with configuration: {config}")
```

### Results Aggregation Callback

```python
import json
from pathlib import Path
from typing import Any, Dict, List

from hydra.core.utils import JobReturn, JobStatus
from hydra.experimental.callback import Callback
from omegaconf import DictConfig, OmegaConf


class ResultsAggregatorCallback(Callback):
    """Callback for aggregating results from multirun experiments.

    Results are accumulated on the instance, so only jobs whose `on_job_end`
    ran on this same callback object end up in the output file.
    """

    def __init__(self, output_file: str = "results.json") -> None:
        self.output_file = output_file
        self.results: List[Dict[str, Any]] = []

    def on_multirun_start(self, config: DictConfig, **kwargs: Any) -> None:
        self.results = []  # Reset results for new multirun
        print("Results aggregation started")

    def on_job_end(
        self,
        config: DictConfig,
        job_return: JobReturn,
        **kwargs: Any
    ) -> None:
        """Collect results from each job."""

        # dict(config) converts only the top level, leaving nested nodes as
        # DictConfig objects that json would stringify. to_container converts
        # recursively. resolve=False keeps interpolations as written so the
        # conversion does not depend on every referenced value being set.
        plain_config = OmegaConf.to_container(config, resolve=False)

        # JobReturn.return_value re-raises the job's exception when the job did
        # not complete, so only read it for successful jobs.
        succeeded = job_return.status == JobStatus.COMPLETED

        result = {
            'job_id': config.get('hydra', {}).get('job', {}).get('id'),
            'config': plain_config,
            'status': str(job_return.status),
            'return_value': job_return.return_value if succeeded else None,
            'hydra_cfg': plain_config.get('hydra', {}),
        }

        self.results.append(result)
        print(f"Collected result from job {result['job_id']}")

    def on_multirun_end(self, config: DictConfig, **kwargs: Any) -> None:
        """Save aggregated results to file."""

        output_path = Path(self.output_file)
        with output_path.open('w', encoding='utf-8') as f:
            # default=str is a last-resort fallback for return values that are
            # not natively JSON-serializable.
            json.dump(self.results, f, indent=2, default=str)

        print(f"Results saved to {output_path}")
        print(f"Total jobs processed: {len(self.results)}")
```

### Callback Registration and Configuration

```python
# Callbacks are configured through Hydra's configuration system: Hydra reads
# them from the `hydra.callbacks` node, a mapping of
# callback name -> instantiation config (`_target_` plus constructor arguments).
# A top-level `callbacks` key in the application config is not consulted.

from dataclasses import dataclass

from hydra.core.config_store import ConfigStore


@dataclass
class CallbackConfig:
    _target_: str
    # Additional callback-specific parameters


@dataclass
class AppConfig:
    name: str = "MyApp"


# Register callback configs in the `hydra/callbacks` group. Each node is nested
# under its own name so that several callbacks can be selected together without
# overwriting one another inside `hydra.callbacks`.
cs = ConfigStore.instance()
cs.store(
    group="hydra/callbacks",
    name="logging_callback",
    node={"logging_callback": CallbackConfig(_target_="__main__.LoggingCallback")},
)

cs.store(
    group="hydra/callbacks",
    name="performance_callback",
    node={"performance_callback": CallbackConfig(_target_="__main__.PerformanceCallback")},
)

# Use in configuration files:
# config.yaml:
#   defaults:
#     - /hydra/callbacks:
#         - logging_callback
#         - performance_callback
```

### Integration with Hydra Application

```python
import time

from hydra import main
from omegaconf import DictConfig


# Callbacks are invoked automatically once they are registered under
# `hydra.callbacks` in the configuration; the task function needs no changes.
@main(version_base=None, config_path="conf", config_name="config")
def my_app(cfg: DictConfig) -> str:
    """Application function with callback integration."""

    print(f"Running application: {cfg.name}")

    # Simulate some work
    time.sleep(1)

    result = f"Processed {cfg.get('items', 0)} items"
    print(result)

    return result  # Return value available in on_job_end callback


if __name__ == "__main__":
    my_app()
```

### Advanced Callback Patterns

```python
import threading
import time
from typing import Any, Dict

from hydra.core.utils import JobReturn, JobStatus
from hydra.experimental.callback import Callback
from hydra.types import TaskFunction
from omegaconf import DictConfig, OmegaConf
from omegaconf.errors import OmegaConfBaseException


class ThreadSafeCallback(Callback):
    """Thread-safe callback for concurrent job execution."""

    def __init__(self) -> None:
        # Guards every read and write of _shared_state.
        self._lock = threading.Lock()
        self._shared_state: Dict[str, Any] = {}

    def on_job_start(
        self,
        config: DictConfig,
        *,
        task_function: TaskFunction,
        **kwargs: Any
    ) -> None:
        # Read the config outside the lock; only the shared dict needs guarding.
        job_id = config.get('hydra', {}).get('job', {}).get('id', 'unknown')
        with self._lock:
            self._shared_state[job_id] = {'status': 'running', 'start_time': time.time()}

    def on_job_end(
        self,
        config: DictConfig,
        job_return: JobReturn,
        **kwargs: Any
    ) -> None:
        job_id = config.get('hydra', {}).get('job', {}).get('id', 'unknown')
        with self._lock:
            if job_id in self._shared_state:
                self._shared_state[job_id].update({
                    'status': 'completed',
                    'end_time': time.time(),
                    # The status enum is hydra.core.utils.JobStatus; JobReturn
                    # has no nested `Status` attribute.
                    'success': job_return.status == JobStatus.COMPLETED,
                })


class ConditionalCallback(Callback):
    """Callback that only executes under certain conditions."""

    def __init__(self, condition_key: str, condition_value: Any) -> None:
        self.condition_key = condition_key
        self.condition_value = condition_value

    def _should_execute(self, config: DictConfig) -> bool:
        """Check if callback should execute based on configuration."""
        try:
            actual_value = OmegaConf.select(config, self.condition_key)
        except OmegaConfBaseException:
            # Treat a key that cannot be selected as "condition not met".
            # Catching only OmegaConf errors avoids swallowing
            # KeyboardInterrupt/SystemExit, which a bare `except:` would.
            return False
        return actual_value == self.condition_value

    def on_job_start(
        self,
        config: DictConfig,
        *,
        task_function: TaskFunction,
        **kwargs: Any
    ) -> None:
        if self._should_execute(config):
            print(f"Conditional callback triggered for {self.condition_key}={self.condition_value}")
```