CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-hydra-core

A framework for elegantly configuring complex applications

Pending
Overview
Eval results
Files

docs/callbacks.md

Callbacks

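Experimental callback API (`hydra.experimental.callback.Callback`) for hooking into Hydra's execution lifecycle. Callbacks run custom logic at specific stages of execution: the start and end of a run, the start and end of a multirun, and the start and end of each individual job. Hydra instantiates the callbacks listed under the `hydra.callbacks` node of the composed configuration.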

Capabilities

Callback Base Class

Base class for implementing custom callbacks that respond to Hydra execution events.

class Callback:
    """Hook interface for Hydra's execution lifecycle.

    Subclasses override only the hooks they care about; every hook defined
    here does nothing and returns ``None``.
    """

    def on_run_start(self, config: DictConfig, **kwargs: Any) -> None:
        """Hook invoked in RUN mode just before the job/application code starts.

        Args:
            config: The composed configuration, with overrides applied.
            **kwargs: Extra context, reserved for future extension.

        Some ``hydra.runtime`` values may still be unpopulated at this point.
        """

    def on_run_end(self, config: DictConfig, **kwargs: Any) -> None:
        """Hook invoked in RUN mode once the job/application code has returned.

        Args:
            config: The configuration the run used.
            **kwargs: Extra context.
        """

    def on_multirun_start(self, config: DictConfig, **kwargs: Any) -> None:
        """Hook invoked in MULTIRUN mode before the first job starts.

        Args:
            config: The base configuration, before any sweep parameters apply.
            **kwargs: Extra context.

        With a launcher, this runs on the local machine, before any
        Sweeper/Launcher has been initialized.
        """

    def on_multirun_end(self, config: DictConfig, **kwargs: Any) -> None:
        """Hook invoked in MULTIRUN mode after every job has returned.

        Args:
            config: The base configuration.
            **kwargs: Extra context.

        With a launcher, this runs on the local machine.
        """

    def on_job_start(
        self,
        config: DictConfig,
        *,
        task_function: TaskFunction,
        **kwargs: Any
    ) -> None:
        """Hook invoked once per Hydra job, in both RUN and MULTIRUN modes.

        Args:
            config: The configuration of this particular job.
            task_function: The function decorated with ``@hydra.main``.
            **kwargs: Extra context.

        With remote launching, this runs on the remote server, alongside the
        application code.
        """

    def on_job_end(
        self,
        config: DictConfig,
        job_return: JobReturn,
        **kwargs: Any
    ) -> None:
        """Hook invoked after each job finishes, in both RUN and MULTIRUN modes.

        Args:
            config: The configuration of the job that just finished.
            job_return: Details about the job's execution and its result.
            **kwargs: Extra context.

        With remote launching, this runs on the remote server, after the
        application code.
        """

Usage Examples

Basic Callback Implementation

from hydra.experimental.callback import Callback
from omegaconf import DictConfig
from hydra.types import TaskFunction
from hydra.core.utils import JobReturn
import logging
from typing import Any

class LoggingCallback(Callback):
    """Simple callback that logs Hydra execution events via ``logging``."""

    def __init__(self):
        self.logger = logging.getLogger(__name__)

    @staticmethod
    def _job_name(config: DictConfig) -> str:
        """Return ``hydra.job.name`` from the job config, or ``'unknown'``."""
        return config.get('hydra', {}).get('job', {}).get('name', 'unknown')

    def on_run_start(self, config: DictConfig, **kwargs: Any) -> None:
        # Lazy %-style arguments: formatting only happens if INFO is enabled.
        self.logger.info("Starting run with config: %s", config.get('name', 'unnamed'))

    def on_run_end(self, config: DictConfig, **kwargs: Any) -> None:
        self.logger.info("Run completed")

    def on_multirun_start(self, config: DictConfig, **kwargs: Any) -> None:
        self.logger.info("Starting multirun")

    def on_multirun_end(self, config: DictConfig, **kwargs: Any) -> None:
        self.logger.info("Multirun completed")

    def on_job_start(
        self,
        config: DictConfig,
        *,
        task_function: TaskFunction,
        **kwargs: Any
    ) -> None:
        self.logger.info("Starting job: %s", self._job_name(config))

    def on_job_end(
        self,
        config: DictConfig,
        job_return: JobReturn,
        **kwargs: Any
    ) -> None:
        # JobReturn.status holds a hydra.core.utils.JobStatus member;
        # JobReturn has no nested ``Status`` enum.
        from hydra.core.utils import JobStatus

        status = "SUCCESS" if job_return.status == JobStatus.COMPLETED else "FAILED"
        self.logger.info("Job %s finished with status: %s", self._job_name(config), status)

Performance Monitoring Callback

import time
from typing import Dict, Any
from hydra.experimental.callback import Callback
from omegaconf import DictConfig
from hydra.core.utils import JobReturn

class PerformanceCallback(Callback):
    """Callback that reports execution timings for runs, multiruns and jobs.

    All timings live on this callback instance. The per-job statistics in the
    multirun summary therefore only cover jobs whose ``on_job_start`` and
    ``on_job_end`` ran in the same process as ``on_multirun_end``. With remote
    launching, the job hooks run on the remote server instead.
    """

    def __init__(self):
        self.start_times: Dict[str, float] = {}
        self.metrics: Dict[str, Any] = {}

    @staticmethod
    def _job_id(config: DictConfig) -> Any:
        """Return ``hydra.job.id`` from the job config, or ``'unknown'``."""
        return config.get('hydra', {}).get('job', {}).get('id', 'unknown')

    def on_run_start(self, config: DictConfig, **kwargs: Any) -> None:
        # perf_counter is monotonic, so durations are immune to clock changes.
        self.start_times['run'] = time.perf_counter()
        print("Performance monitoring started")

    def on_run_end(self, config: DictConfig, **kwargs: Any) -> None:
        duration = time.perf_counter() - self.start_times['run']
        print(f"Total execution time: {duration:.2f} seconds")

    def on_multirun_start(self, config: DictConfig, **kwargs: Any) -> None:
        self.start_times['multirun'] = time.perf_counter()
        self.metrics['jobs_completed'] = 0
        self.metrics['job_durations'] = []
        print("Multirun performance monitoring started")

    def on_multirun_end(self, config: DictConfig, **kwargs: Any) -> None:
        total_duration = time.perf_counter() - self.start_times['multirun']
        jobs = self.metrics.get('jobs_completed', 0)
        durations = self.metrics.get('job_durations', [])
        # Average the measured per-job durations. Dividing the multirun wall time
        # by the job count would include sweeper/launcher overhead and be wrong
        # for launchers that run jobs in parallel.
        avg_job_time = sum(durations) / len(durations) if durations else 0

        print(f"Multirun completed in {total_duration:.2f} seconds")
        print(f"Jobs completed: {jobs}")
        print(f"Average job time: {avg_job_time:.2f} seconds")

    def on_job_start(
        self,
        config: DictConfig,
        *,
        task_function: "TaskFunction",  # quoted: TaskFunction is not imported in this example
        **kwargs: Any
    ) -> None:
        self.start_times[f'job_{self._job_id(config)}'] = time.perf_counter()

    def on_job_end(
        self,
        config: DictConfig,
        job_return: JobReturn,
        **kwargs: Any
    ) -> None:
        job_id = self._job_id(config)
        start = self.start_times.pop(f'job_{job_id}', None)

        if start is not None:
            duration = time.perf_counter() - start
            print(f"Job {job_id} completed in {duration:.2f} seconds")
            self.metrics.setdefault('job_durations', []).append(duration)

        self.metrics['jobs_completed'] = self.metrics.get('jobs_completed', 0) + 1

Configuration Validation Callback

from hydra.experimental.callback import Callback
from omegaconf import DictConfig
from typing import Any

class ValidationCallback(Callback):
    """Callback that validates the job configuration before the task runs.

    NOTE(review): some Hydra releases may catch exceptions raised inside
    callbacks and report them as warnings instead of aborting the job. Confirm
    the behavior of your Hydra version before relying on the ``ValueError``
    raised here to stop execution.
    """

    def __init__(self, required_keys: list = None):
        # Copy the list so later mutation of the caller's list has no effect here.
        # ``None`` or an empty list means "no required keys".
        self.required_keys = list(required_keys) if required_keys else []

    def on_job_start(
        self,
        config: DictConfig,
        *,
        task_function: "TaskFunction",  # quoted: not imported in this example
        **kwargs: Any
    ) -> None:
        """Validate the configuration before the job executes.

        Raises:
            ValueError: If a required top-level key is absent (or ``???``),
                or if ``database.port`` is not positive.
        """
        # DictConfig's ``in`` is False for both absent keys and '???' values.
        for key in self.required_keys:
            if key not in config:
                raise ValueError(f"Required configuration key missing: {key}")

        # .get() returns None for an absent key and for a '???' value.
        # hasattr() would propagate MissingMandatoryValue for '???' instead.
        database = config.get('database')
        if database:
            if database.get('port', 0) <= 0:
                raise ValueError("Database port must be positive")

        print("Configuration validation passed")

    def on_job_end(
        self,
        config: DictConfig,
        job_return: "JobReturn",  # quoted: not imported in this example
        **kwargs: Any
    ) -> None:
        """Print the configuration of jobs that failed."""
        # JobReturn.status holds a hydra.core.utils.JobStatus member;
        # JobReturn has no nested ``Status`` enum.
        from hydra.core.utils import JobStatus

        if job_return.status == JobStatus.FAILED:
            print(f"Job failed with configuration: {config}")

Results Aggregation Callback

import json
from pathlib import Path
from typing import List, Any
from hydra.experimental.callback import Callback
from omegaconf import DictConfig
from hydra.core.utils import JobReturn

class ResultsAggregatorCallback(Callback):
    """Callback that aggregates per-job results of a multirun into a JSON file.

    Results are collected on this instance. Only jobs whose ``on_job_end`` runs
    in the same process as ``on_multirun_end`` are included. With remote
    launching, the job hooks run on the remote server.
    """

    def __init__(self, output_file: str = "results.json"):
        self.output_file = output_file
        self.results: List[dict] = []

    def on_multirun_start(self, config: DictConfig, **kwargs: Any) -> None:
        self.results = []  # Reset results for new multirun
        print("Results aggregation started")

    def on_job_end(
        self,
        config: DictConfig,
        job_return: JobReturn,
        **kwargs: Any
    ) -> None:
        """Record the configuration, status and return value of one job."""
        from hydra.core.utils import JobStatus
        from omegaconf import OmegaConf

        # Deep-convert to plain containers so nested nodes serialize as JSON.
        # resolve=False keeps interpolations as "${...}" strings rather than
        # evaluating resolvers inside the callback.
        container = OmegaConf.to_container(config, resolve=False)
        hydra_cfg = container.pop('hydra', {})

        # JobReturn.return_value re-raises the stored exception for failed
        # jobs, so only read it when the job completed.
        succeeded = job_return.status == JobStatus.COMPLETED

        result = {
            'job_id': config.get('hydra', {}).get('job', {}).get('id'),
            'config': container,
            'status': str(job_return.status),
            'return_value': job_return.return_value if succeeded else None,
            'hydra_cfg': hydra_cfg,
        }

        self.results.append(result)
        print(f"Collected result from job {result['job_id']}")

    def on_multirun_end(self, config: DictConfig, **kwargs: Any) -> None:
        """Write the aggregated results to ``output_file``."""
        output_path = Path(self.output_file)
        with output_path.open('w', encoding='utf-8') as f:
            # default=str covers non-JSON values such as enums or custom return objects.
            json.dump(self.results, f, indent=2, default=str)

        print(f"Results saved to {output_path}")
        print(f"Total jobs processed: {len(self.results)}")

Callback Registration and Configuration

# Hydra only instantiates callbacks listed under the ``hydra.callbacks`` node
# of the composed config. That node maps a name of your choosing to a config
# whose ``_target_`` is the callback class; any extra keys next to
# ``_target_`` are passed to the callback's constructor. A top-level
# ``callbacks`` key or config group in your application config is ignored by
# Hydra.

from hydra import main, initialize, compose
from hydra.core.config_store import ConfigStore
from dataclasses import asdict, dataclass, field
from typing import List


@dataclass
class CallbackConfig:
    _target_: str
    # Subclass and add fields for callback-specific constructor arguments


@dataclass
class AppConfig:
    name: str = "MyApp"


# Register a primary config named "config" that enables both callbacks.
# Use either this or a conf/config.yaml with the same name, not both.
cs = ConfigStore.instance()
cs.store(
    name="config",
    node={
        **asdict(AppConfig()),
        "hydra": {
            "callbacks": {
                "logging": CallbackConfig(_target_="__main__.LoggingCallback"),
                "performance": CallbackConfig(_target_="__main__.PerformanceCallback"),
            }
        },
    },
)

# Equivalent configuration file (conf/config.yaml):
# name: MyApp
# hydra:
#   callbacks:
#     logging:
#       _target_: __main__.LoggingCallback
#     performance:
#       _target_: __main__.PerformanceCallback

Integration with Hydra Application

from hydra import main
from omegaconf import DictConfig

# Once callbacks are listed in the configuration, Hydra invokes them around
# this function on its own; the task function contains no callback code.
@main(version_base=None, config_path="conf", config_name="config")
def my_app(cfg: DictConfig) -> str:
    """Run a small demo workload and return a summary string.

    The returned string is what callbacks see via the ``JobReturn`` passed
    to ``on_job_end``.
    """
    print(f"Running application: {cfg.name}")

    import time  # only needed for the simulated work below
    time.sleep(1)  # stand-in for real work

    summary = f"Processed {cfg.get('items', 0)} items"
    print(summary)
    return summary


if __name__ == "__main__":
    my_app()

Advanced Callback Patterns

from hydra.experimental.callback import Callback
from omegaconf import DictConfig
from typing import Any, Dict
import threading

class ThreadSafeCallback(Callback):
    """Callback that tracks per-job status in a dict guarded by a lock.

    The ``threading.Lock`` only serializes hooks that share this instance
    within one process.
    NOTE(review): with remote launching, the job hooks run on the remote
    server, so state recorded there is not visible in this instance. Confirm
    how your launcher distributes callbacks before relying on this state.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._shared_state: Dict[str, Any] = {}

    @staticmethod
    def _job_id(config: DictConfig) -> Any:
        """Return ``hydra.job.id`` from the job config, or ``'unknown'``."""
        return config.get('hydra', {}).get('job', {}).get('id', 'unknown')

    def on_job_start(
        self,
        config: DictConfig,
        *,
        task_function: "TaskFunction",  # quoted: not imported in this example
        **kwargs: Any
    ) -> None:
        import time  # not imported at the top of this example

        # Reading the job's own config needs no lock; only the shared dict does.
        job_id = self._job_id(config)
        with self._lock:
            self._shared_state[job_id] = {'status': 'running', 'start_time': time.time()}

    def on_job_end(
        self,
        config: DictConfig,
        job_return: "JobReturn",  # quoted: not imported in this example
        **kwargs: Any
    ) -> None:
        import time  # not imported at the top of this example

        # JobReturn.status holds a hydra.core.utils.JobStatus member;
        # JobReturn has no nested ``Status`` enum.
        from hydra.core.utils import JobStatus

        job_id = self._job_id(config)
        succeeded = job_return.status == JobStatus.COMPLETED
        with self._lock:
            state = self._shared_state.get(job_id)
            if state is not None:
                # 'completed' means the job finished; 'success' says whether it succeeded.
                state.update({
                    'status': 'completed',
                    'end_time': time.time(),
                    'success': succeeded,
                })

class ConditionalCallback(Callback):
    """Callback that only acts when a config value equals an expected value."""

    # Sentinel that tells "key absent or '???'" apart from a real ``None`` value.
    _NOT_FOUND = object()

    def __init__(self, condition_key: str, condition_value: Any):
        self.condition_key = condition_key  # dotted path, e.g. "db.driver"
        self.condition_value = condition_value

    def _should_execute(self, config: DictConfig) -> bool:
        """Return True if ``condition_key`` exists in ``config`` and equals ``condition_value``.

        Best-effort: any error while selecting the key yields False rather
        than propagating.
        """
        from omegaconf import OmegaConf

        try:
            actual_value = OmegaConf.select(
                config, self.condition_key, default=self._NOT_FOUND
            )
            # A missing key must not match condition_value=None.
            return (
                actual_value is not self._NOT_FOUND
                and actual_value == self.condition_value
            )
        except Exception:
            # Deliberately swallowed (e.g. an interpolation that fails to resolve).
            # Unlike a bare ``except:``, this still lets KeyboardInterrupt and
            # SystemExit propagate.
            return False

    def on_job_start(
        self,
        config: DictConfig,
        *,
        task_function: "TaskFunction",  # quoted: not imported in this example
        **kwargs: Any
    ) -> None:
        if self._should_execute(config):
            print(f"Conditional callback triggered for {self.condition_key}={self.condition_value}")

Install with Tessl CLI

npx tessl i tessl/pypi-hydra-core

docs

callbacks.md

composition.md

config-schema.md

config-store.md

errors.md

index.md

initialization.md

main-decorator.md

types.md

utilities.md

tile.json