A framework for elegantly configuring complex applications
—
Experimental callback API for hooking into Hydra's execution lifecycle. Callbacks enable custom logic at different stages of application execution including run start/end, multirun events, and individual job events.
Base class for implementing custom callbacks that respond to Hydra execution events.
class Callback:
    """Interface for hooks that Hydra invokes around runs and jobs.

    Every hook defined here is a no-op; subclasses override only the
    events they are interested in.
    """

    def on_run_start(self, config: DictConfig, **kwargs: Any) -> None:
        """Hook fired in RUN mode before the job/application code starts.

        Args:
            config: Composed configuration with overrides applied.
            **kwargs: Additional context (reserved for future extensibility).

        Note:
            Some ``hydra.runtime`` configs may not be populated yet.
        """

    def on_run_end(self, config: DictConfig, **kwargs: Any) -> None:
        """Hook fired in RUN mode after the job/application code returns.

        Args:
            config: The configuration used for the run.
            **kwargs: Additional context.
        """

    def on_multirun_start(self, config: DictConfig, **kwargs: Any) -> None:
        """Hook fired in MULTIRUN mode before any job starts.

        Args:
            config: Base configuration before parameter sweeps.
            **kwargs: Additional context.

        Note:
            When using a launcher, this executes on the local machine
            before any Sweeper/Launcher is initialized.
        """

    def on_multirun_end(self, config: DictConfig, **kwargs: Any) -> None:
        """Hook fired in MULTIRUN mode after all jobs have returned.

        Args:
            config: Base configuration.
            **kwargs: Additional context.

        Note:
            When using a launcher, this executes on the local machine.
        """

    def on_job_start(
        self,
        config: DictConfig,
        *,
        task_function: TaskFunction,
        **kwargs: Any
    ) -> None:
        """Hook fired for each Hydra job, in both RUN and MULTIRUN modes.

        Args:
            config: Configuration for this specific job.
            task_function: The function decorated with ``@hydra.main``
                (keyword-only).
            **kwargs: Additional context.

        Note:
            In remote launching, this executes on the remote server
            along with your application code.
        """

    def on_job_end(
        self,
        config: DictConfig,
        job_return: JobReturn,
        **kwargs: Any
    ) -> None:
        """Hook fired after each job completes, in both RUN and MULTIRUN modes.

        Args:
            config: Configuration for the completed job.
            job_return: Information about job execution and results.
            **kwargs: Additional context.

        Note:
            In remote launching, this executes on the remote server
            after your application code.
        """


from hydra.experimental.callback import Callback
from omegaconf import DictConfig
from hydra.types import TaskFunction
from hydra.core.utils import JobReturn
import logging
from typing import Any
class LoggingCallback(Callback):
    """Simple callback that logs execution events.

    Each lifecycle hook writes one INFO record to the logger named after
    this module.
    """

    def __init__(self) -> None:
        self.logger = logging.getLogger(__name__)

    @staticmethod
    def _job_name(config: DictConfig) -> Any:
        """Return ``hydra.job.name`` from *config*, or ``'unknown'`` if absent."""
        return config.get('hydra', {}).get('job', {}).get('name', 'unknown')

    def on_run_start(self, config: DictConfig, **kwargs: Any) -> None:
        """Log the application name (``config.name``, default ``'unnamed'``)."""
        # Lazy %-style arguments: the message is only formatted if INFO is enabled.
        self.logger.info("Starting run with config: %s", config.get('name', 'unnamed'))

    def on_run_end(self, config: DictConfig, **kwargs: Any) -> None:
        """Log that the run finished."""
        self.logger.info("Run completed")

    def on_multirun_start(self, config: DictConfig, **kwargs: Any) -> None:
        """Log that a multirun is starting."""
        self.logger.info("Starting multirun")

    def on_multirun_end(self, config: DictConfig, **kwargs: Any) -> None:
        """Log that the multirun finished."""
        self.logger.info("Multirun completed")

    def on_job_start(
        self,
        config: DictConfig,
        *,
        task_function: TaskFunction,
        **kwargs: Any
    ) -> None:
        """Log the name of the job that is about to run."""
        self.logger.info("Starting job: %s", self._job_name(config))

    def on_job_end(
        self,
        config: DictConfig,
        job_return: JobReturn,
        **kwargs: Any
    ) -> None:
        """Log whether the job completed successfully.

        Any status other than ``JobStatus.COMPLETED`` is reported as FAILED.
        """
        # The status enum is the module-level JobStatus class in
        # hydra.core.utils; JobReturn has no nested ``Status`` attribute.
        from hydra.core.utils import JobStatus

        status = "SUCCESS" if job_return.status == JobStatus.COMPLETED else "FAILED"
        self.logger.info(
            "Job %s finished with status: %s", self._job_name(config), status
        )


import time
from typing import Dict, Any
from hydra.experimental.callback import Callback
from omegaconf import DictConfig
from hydra.core.utils import JobReturn
class PerformanceCallback(Callback):
    """Callback for monitoring execution performance.

    Start timestamps are kept in ``start_times`` (keyed by ``'run'``,
    ``'multirun'`` or ``'job_<id>'``) and counters in ``metrics``. All
    reporting goes to stdout via ``print``.
    """

    def __init__(self) -> None:
        self.start_times: Dict[str, float] = {}
        self.metrics: Dict[str, Any] = {}

    def _elapsed(self, key: str) -> float:
        """Return seconds since *key* was started, or 0.0 if it never was.

        Falling back to 0.0 keeps an end hook from raising KeyError when the
        matching start hook did not run on this instance.
        """
        started = self.start_times.get(key)
        return 0.0 if started is None else time.perf_counter() - started

    def on_run_start(self, config: DictConfig, **kwargs: Any) -> None:
        """Record the start time of a single run."""
        # perf_counter is monotonic, so durations are immune to clock changes.
        self.start_times['run'] = time.perf_counter()
        print("Performance monitoring started")

    def on_run_end(self, config: DictConfig, **kwargs: Any) -> None:
        """Print the total duration of the run."""
        duration = self._elapsed('run')
        print(f"Total execution time: {duration:.2f} seconds")

    def on_multirun_start(self, config: DictConfig, **kwargs: Any) -> None:
        """Record the multirun start time and reset the job counter."""
        self.start_times['multirun'] = time.perf_counter()
        self.metrics['jobs_completed'] = 0
        print("Multirun performance monitoring started")

    def on_multirun_end(self, config: DictConfig, **kwargs: Any) -> None:
        """Print total multirun duration, job count and time per job."""
        total_duration = self._elapsed('multirun')
        jobs = self.metrics.get('jobs_completed', 0)
        # Total elapsed time divided by job count; 0 when no job was counted.
        avg_job_time = total_duration / jobs if jobs > 0 else 0
        print(f"Multirun completed in {total_duration:.2f} seconds")
        print(f"Jobs completed: {jobs}")
        print(f"Average job time: {avg_job_time:.2f} seconds")

    def on_job_start(
        self,
        config: DictConfig,
        *,
        task_function: TaskFunction,
        **kwargs: Any
    ) -> None:
        """Record the start time of the job identified by ``hydra.job.id``."""
        job_id = config.get('hydra', {}).get('job', {}).get('id', 'unknown')
        self.start_times[f'job_{job_id}'] = time.perf_counter()

    def on_job_end(
        self,
        config: DictConfig,
        job_return: JobReturn,
        **kwargs: Any
    ) -> None:
        """Print the job's duration (if its start was seen) and count it."""
        job_id = config.get('hydra', {}).get('job', {}).get('id', 'unknown')
        # pop() both reads and discards the entry so finished jobs do not
        # accumulate in start_times.
        started = self.start_times.pop(f'job_{job_id}', None)
        if started is not None:
            duration = time.perf_counter() - started
            print(f"Job {job_id} completed in {duration:.2f} seconds")
        self.metrics['jobs_completed'] = self.metrics.get('jobs_completed', 0) + 1


from hydra.experimental.callback import Callback
from omegaconf import DictConfig
from typing import Any
class ValidationCallback(Callback):
    """Callback for validating configurations.

    Args:
        required_keys: Top-level keys that must be present in every job
            config. ``None`` (the default) means no keys are required.
    """

    def __init__(self, required_keys: "list | None" = None) -> None:
        # Copy so later mutation of the caller's sequence cannot silently
        # change what is validated.
        self.required_keys = list(required_keys) if required_keys else []

    def on_job_start(
        self,
        config: DictConfig,
        *,
        task_function: TaskFunction,
        **kwargs: Any
    ) -> None:
        """Validate configuration before job execution.

        Raises:
            ValueError: If a required key is missing, or if a non-empty
                ``database`` section has a missing, null or non-positive
                ``port``.
        """
        # Check required keys
        for key in self.required_keys:
            if key not in config:
                raise ValueError(f"Required configuration key missing: {key}")

        # Custom validation logic
        if hasattr(config, 'database') and config.database:
            port = config.database.get('port', 0)
            # A null port cannot be compared with 0; report it as a
            # validation error instead of letting a TypeError escape.
            if port is None or port <= 0:
                raise ValueError("Database port must be positive")

        print("Configuration validation passed")

    def on_job_end(
        self,
        config: DictConfig,
        job_return: JobReturn,
        **kwargs: Any
    ) -> None:
        """Log job completion status (prints the config of failed jobs)."""
        # The status enum is the module-level JobStatus class in
        # hydra.core.utils; JobReturn has no nested ``Status`` attribute.
        from hydra.core.utils import JobStatus

        if job_return.status == JobStatus.FAILED:
            print(f"Job failed with configuration: {config}")


import json
from pathlib import Path
from typing import List, Any
from hydra.experimental.callback import Callback
from omegaconf import DictConfig
from hydra.core.utils import JobReturn
class ResultsAggregatorCallback(Callback):
    """Callback for aggregating results from multirun experiments.

    One record per finished job is appended to ``results``; the list is
    written as JSON to ``output_file`` when the multirun ends.

    Args:
        output_file: Path of the JSON file to write.
    """

    def __init__(self, output_file: str = "results.json") -> None:
        self.output_file = output_file
        self.results: List[Dict[str, Any]] = []

    def on_multirun_start(self, config: DictConfig, **kwargs: Any) -> None:
        """Clear any previously collected results."""
        self.results = []  # Reset results for new multirun
        print("Results aggregation started")

    def on_job_end(
        self,
        config: DictConfig,
        job_return: JobReturn,
        **kwargs: Any
    ) -> None:
        """Collect results from each job."""
        from hydra.core.utils import JobStatus
        from omegaconf import OmegaConf

        hydra_node = config.get('hydra')
        # JobReturn.return_value re-raises the job's exception when the job
        # did not complete, so only read it for completed jobs.
        if job_return.status == JobStatus.COMPLETED:
            return_value = job_return.return_value
        else:
            return_value = None

        result = {
            'job_id': config.get('hydra', {}).get('job', {}).get('id'),
            # to_container converts nested nodes recursively into plain
            # dicts/lists; dict(config) is shallow and would leave nested
            # DictConfig objects to be stringified by json's default=str.
            # resolve=False keeps interpolations as written, so values that
            # cannot be resolved in the current mode do not raise.
            'config': OmegaConf.to_container(config, resolve=False),
            'status': str(job_return.status),
            'return_value': return_value,
            'hydra_cfg': (
                OmegaConf.to_container(hydra_node, resolve=False)
                if hydra_node is not None
                else {}
            ),
        }
        self.results.append(result)
        print(f"Collected result from job {result['job_id']}")

    def on_multirun_end(self, config: DictConfig, **kwargs: Any) -> None:
        """Save aggregated results to file."""
        output_path = Path(self.output_file)
        # default=str is a fallback for return values that are not natively
        # JSON-serializable.
        with output_path.open('w', encoding='utf-8') as f:
            json.dump(self.results, f, indent=2, default=str)
        print(f"Results saved to {output_path}")
        print(f"Total jobs processed: {len(self.results)}")


# Callbacks are typically configured through Hydra's configuration system
# or registered programmatically
from hydra import main, initialize, compose
from hydra.core.config_store import ConfigStore
from dataclasses import dataclass, field
from typing import List
@dataclass
class CallbackConfig:
    """Structured-config node that names one callback class."""

    # Dotted import path of the callback class,
    # e.g. "__main__.LoggingCallback".
    _target_: str
    # Additional callback-specific parameters
@dataclass
class AppConfig:
    """Top-level application config: a display name plus callback nodes."""

    name: str = "MyApp"
    # default_factory gives every instance its own list; dataclasses reject
    # a bare mutable [] default.
    callbacks: List[CallbackConfig] = field(default_factory=list)
# Register callback configs
# Both nodes go into the "callbacks" config group of the global ConfigStore,
# so a defaults list can select them by name.
cs = ConfigStore.instance()
cs.store(
    group="callbacks",
    name="logging_callback",
    node=CallbackConfig(_target_="__main__.LoggingCallback"),
)
cs.store(
    group="callbacks",
    name="performance_callback",
    node=CallbackConfig(_target_="__main__.PerformanceCallback"),
)
# Use in configuration files:
# config.yaml:
#   defaults:
#     - callbacks: [logging_callback, performance_callback]
# NOTE(review): Hydra's callback documentation configures callbacks under
# `hydra.callbacks`; confirm that this top-level `callbacks` group is
# actually picked up by Hydra before relying on it.
from hydra import main
from omegaconf import DictConfig
# Callbacks are automatically invoked when registered through configuration
@main(version_base=None, config_path="conf", config_name="config")
def my_app(cfg: DictConfig) -> str:
    """Application function with callback integration.

    Prints the application name, pauses for one second to stand in for
    real work, then prints and returns a summary string built from
    ``cfg.items`` (0 when the key is absent).
    """
    from time import sleep

    print(f"Running application: {cfg.name}")

    sleep(1)  # Simulate some work

    summary = f"Processed {cfg.get('items', 0)} items"
    print(summary)
    return summary  # Return value available in on_job_end callback
# Invoke the Hydra-decorated entry point only when this file is executed as
# a script, not when it is imported.
if __name__ == "__main__":
    my_app()


from hydra.experimental.callback import Callback
from omegaconf import DictConfig
from typing import Any, Dict
import threading
class ThreadSafeCallback(Callback):
    """Thread-safe callback for concurrent job execution.

    Per-job records are kept in ``_shared_state`` keyed by ``hydra.job.id``;
    every read-modify-write of that dict happens under ``_lock``.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._shared_state: Dict[str, Any] = {}

    def on_job_start(
        self,
        config: DictConfig,
        *,
        task_function: TaskFunction,
        **kwargs: Any
    ) -> None:
        """Record the job as running together with its start timestamp."""
        # Build the record before taking the lock so the critical section
        # covers only the shared-dict write.
        job_id = config.get('hydra', {}).get('job', {}).get('id', 'unknown')
        record = {'status': 'running', 'start_time': time.time()}
        with self._lock:
            self._shared_state[job_id] = record

    def on_job_end(
        self,
        config: DictConfig,
        job_return: JobReturn,
        **kwargs: Any
    ) -> None:
        """Mark a previously started job as completed.

        Jobs whose start was never recorded on this instance are ignored.
        """
        # The status enum is the module-level JobStatus class in
        # hydra.core.utils; JobReturn has no nested ``Status`` attribute.
        from hydra.core.utils import JobStatus

        job_id = config.get('hydra', {}).get('job', {}).get('id', 'unknown')
        outcome = {
            'status': 'completed',
            'end_time': time.time(),
            'success': job_return.status == JobStatus.COMPLETED,
        }
        with self._lock:
            if job_id in self._shared_state:
                self._shared_state[job_id].update(outcome)
class ConditionalCallback(Callback):
"""Callback that only executes under certain conditions."""
def __init__(self, condition_key: str, condition_value: Any):
self.condition_key = condition_key
self.condition_value = condition_value
def _should_execute(self, config: DictConfig) -> bool:
"""Check if callback should execute based on configuration."""
from omegaconf import OmegaConf
try:
actual_value = OmegaConf.select(config, self.condition_key)
return actual_value == self.condition_value
except:
return False
def on_job_start(
self,
config: DictConfig,
*,
task_function: TaskFunction,
**kwargs: Any
) -> None:
if self._should_execute(config):
print(f"Conditional callback triggered for {self.condition_key}={self.condition_value}")Install with Tessl CLI
npx tessl i tessl/pypi-hydra-core