Training Callbacks

Custom callbacks for training monitoring, evaluation, hyperparameter optimization, and logging. These callbacks extend Stable Baselines3's callback system with specialized functionality for RL Zoo3's training workflows.

Core Imports

from typing import Optional

import optuna
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
from stable_baselines3.common.vec_env import VecEnv

from rl_zoo3.callbacks import (
    TrialEvalCallback,
    SaveVecNormalizeCallback,
    ParallelTrainCallback,
    RawStatisticsCallback,
)

Capabilities

Trial Evaluation Callback

Specialized callback for Optuna hyperparameter optimization trials, providing evaluation and pruning functionality during optimization runs.

class TrialEvalCallback(EvalCallback):
    """
    Callback used for evaluating and reporting a trial during hyperparameter optimization.
    Extends EvalCallback with Optuna trial integration.
    """
    
    def __init__(
        self,
        eval_env: VecEnv,
        trial: optuna.Trial,
        n_eval_episodes: int = 5,
        eval_freq: int = 10000,
        deterministic: bool = True,
        verbose: int = 0,
        best_model_save_path: Optional[str] = None,
        log_path: Optional[str] = None
    ) -> None:
        """
        Initialize TrialEvalCallback.
        
        Parameters:
        - eval_env: Vectorized evaluation environment
        - trial: Optuna trial object for reporting results
        - n_eval_episodes: Number of episodes for each evaluation
        - eval_freq: Frequency of evaluation (in timesteps)
        - deterministic: Whether to use deterministic actions during evaluation
        - verbose: Verbosity level
        - best_model_save_path: Path to save best model
        - log_path: Path for evaluation logs
        """
    
    def _on_step(self) -> bool:
        """
        Called at each training step. At the configured frequency, runs an
        evaluation, reports the result to Optuna, and stops training if the
        trial should be pruned.
        
        Returns:
        bool: Whether training should continue (False once the trial is pruned)
        """

Usage example:

import optuna
from stable_baselines3 import PPO
from rl_zoo3 import create_test_env
from rl_zoo3.callbacks import TrialEvalCallback

def objective(trial):
    # Sample a hyperparameter for this trial
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)
    
    # Create evaluation environment
    eval_env = create_test_env("CartPole-v1", n_envs=1)
    
    # Evaluate periodically and report intermediate results to Optuna
    eval_callback = TrialEvalCallback(
        eval_env=eval_env,
        trial=trial,
        n_eval_episodes=10,
        eval_freq=1000,
        deterministic=True,
        verbose=0
    )
    
    model = PPO("MlpPolicy", "CartPole-v1", learning_rate=learning_rate, verbose=0)
    model.learn(total_timesteps=10000, callback=eval_callback)
    
    # Propagate the pruning decision made by the callback
    if eval_callback.is_pruned:
        raise optuna.TrialPruned()
    
    return eval_callback.best_mean_reward

# Run optimization
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=10)
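
The study's pruner controls when trials are cut short. To tune that behavior explicitly, pass one to create_study; a minimal sketch with a MedianPruner (the pruner parameters here are illustrative):

pruner = optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=3)
study = optuna.create_study(direction='maximize', pruner=pruner)
study.optimize(objective, n_trials=10)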

VecNormalize Saving Callback

Callback for automatically saving VecNormalize statistics during training, ensuring normalization parameters are preserved with the model.

class SaveVecNormalizeCallback(BaseCallback):
    """
    Callback for saving VecNormalize statistics at regular intervals.
    """
    
    def __init__(
        self,
        save_freq: int,
        save_path: str,
        name_prefix: Optional[str] = None,
        verbose: int = 0
    ):
        """
        Initialize SaveVecNormalizeCallback.
        
        Parameters:
        - save_freq: Frequency of saving (in timesteps)
        - save_path: Directory path for saving statistics
        - name_prefix: Prefix for saved files; if None (the default), a single
          vecnormalize.pkl file is kept and overwritten
        - verbose: Verbosity level
        """
    
    def _on_step(self) -> bool:
        """
        Called at each training step. Saves VecNormalize stats at specified frequency.
        
        Returns:
        bool: Always True (never stops training)
        """

Usage example:

from rl_zoo3.callbacks import SaveVecNormalizeCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize
from stable_baselines3 import PPO

# Create a vectorized, normalized environment
base_env = make_vec_env("CartPole-v1", n_envs=1)
env = VecNormalize(base_env, norm_obs=True, norm_reward=True)

# Save the VecNormalize statistics every 5000 steps
save_callback = SaveVecNormalizeCallback(
    save_freq=5000,
    save_path="./logs/",
    name_prefix="ppo_vecnormalize",
    verbose=1
)

# Train with callback
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=20000, callback=save_callback)
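
At evaluation time, the saved statistics can be restored with VecNormalize.load. A minimal sketch; the filename is illustrative and depends on name_prefix and the timestep at which the callback saved:

from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize

# Rebuild the environment and restore the saved normalization statistics
test_env = make_vec_env("CartPole-v1", n_envs=1)
test_env = VecNormalize.load("./logs/ppo_vecnormalize_20000_steps.pkl", test_env)

# Freeze the running statistics and report raw rewards during evaluation
test_env.training = False
test_env.norm_reward = False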

Parallel Training Callback

Callback that runs gradient updates in a background thread while the main thread keeps collecting experience, letting environment interaction and training overlap. Intended for off-policy algorithms (e.g. SAC, TQC).

class ParallelTrainCallback(BaseCallback):
    """
    Callback for collecting experience and performing gradient steps at the
    same time, using two separate threads. Normally used with off-policy
    algorithms (e.g. SAC, TQC) and train_freq=(1, "episode").
    """
    
    def __init__(
        self,
        gradient_steps: int = 100,
        verbose: int = 0,
        sleep_time: float = 0.0
    ):
        """
        Initialize ParallelTrainCallback.
        
        Parameters:
        - gradient_steps: Number of gradient steps to perform before making
          the updated parameters available to the exploration thread
        - verbose: Verbosity level
        - sleep_time: Optional sleep (in seconds) per step to limit the fps
          of the thread collecting experience
        """
    
    def _on_training_start(self) -> None:
        """
        Called when training starts. Prepares the model copy used by the
        background training thread.
        """
    
    def _on_step(self) -> bool:
        """
        Called at each training step. Optionally sleeps to limit the
        experience-collection rate.
        
        Returns:
        bool: Whether training should continue
        """
    
    def _on_training_end(self) -> None:
        """
        Called when training ends. Waits for the training thread to terminate.
        """

Raw Statistics Callback

Callback that logs raw, per-episode statistics (episodic return and episode length) to TensorBoard, complementing the aggregated metrics that Stable Baselines3 reports.

class RawStatisticsCallback(BaseCallback):
    """
    Callback for logging raw episode statistics (episodic return and length).
    Requires TensorBoard logging to be active on the model.
    """
    
    def __init__(self, verbose: int = 0):
        """
        Initialize RawStatisticsCallback.
        
        Parameters:
        - verbose: Verbosity level
        """
    
    def _on_training_start(self) -> None:
        """
        Called when training starts. Retrieves the TensorBoard writer from
        the logger; fails if TensorBoard logging is not active.
        """
    
    def _on_step(self) -> bool:
        """
        Called at each training step. Writes the raw episodic return and
        episode length of every episode that finished at this step.
        
        Returns:
        bool: Always True (never stops training)
        """

Usage example:

from rl_zoo3.callbacks import RawStatisticsCallback
from stable_baselines3 import PPO

# Create callback
stats_callback = RawStatisticsCallback(verbose=1)

# TensorBoard logging must be enabled for the raw statistics to be written
model = PPO("MlpPolicy", "CartPole-v1", verbose=1, tensorboard_log="./tb_logs/")
model.learn(total_timesteps=50000, callback=stats_callback)

Callback Combinations

You can combine multiple callbacks for comprehensive training monitoring:

from rl_zoo3.callbacks import SaveVecNormalizeCallback, RawStatisticsCallback
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback, CallbackList
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize

# Training env is normalized so that SaveVecNormalizeCallback has statistics
# to save; the eval env uses the same wrapping with frozen statistics
env = VecNormalize(make_vec_env("CartPole-v1", n_envs=1), norm_obs=True, norm_reward=True)
eval_env = VecNormalize(make_vec_env("CartPole-v1", n_envs=1), training=False, norm_reward=False)

# Create multiple callbacks
checkpoint_callback = CheckpointCallback(
    save_freq=10000,
    save_path="./checkpoints/",
    name_prefix="ppo_model"
)

eval_callback = EvalCallback(
    eval_env=eval_env,
    best_model_save_path="./best_models/",
    log_path="./eval_logs/",
    eval_freq=5000,
    n_eval_episodes=10,
    deterministic=True
)

save_vec_callback = SaveVecNormalizeCallback(
    save_freq=5000,
    save_path="./vecnormalize/",
    verbose=1
)

stats_callback = RawStatisticsCallback(verbose=1)

# Combine callbacks
callback_list = CallbackList([
    checkpoint_callback,
    eval_callback,
    save_vec_callback,
    stats_callback
])

# Train with all callbacks (TensorBoard logging enabled for RawStatisticsCallback)
model = PPO("MlpPolicy", env, verbose=1, tensorboard_log="./tb_logs/")
model.learn(total_timesteps=100000, callback=callback_list)

Integration with ExperimentManager

The ExperimentManager automatically creates and configures appropriate callbacks based on the training setup:

from rl_zoo3.exp_manager import ExperimentManager
import argparse

# ExperimentManager will automatically create callbacks based on its configuration
args = argparse.Namespace(
    algo='ppo',
    env='CartPole-v1',
    n_timesteps=50000,
    eval_freq=5000,
    save_freq=10000,
    verbose=1
)

exp_manager = ExperimentManager(
    args=args,
    algo='ppo',
    env_id='CartPole-v1',
    log_folder='./logs',
    n_timesteps=50000,
    eval_freq=5000,   # creates an evaluation callback
    save_freq=10000,  # creates a checkpoint callback
)

# setup_experiment() returns the model together with the saved hyperparameters
results = exp_manager.setup_experiment()
if results is not None:
    model, saved_hyperparams = results
    exp_manager.learn(model)  # uses the automatically created callbacks

Custom Callback Integration

You can also create custom callbacks that work with RL Zoo3's training system:

from stable_baselines3.common.callbacks import BaseCallback

class CustomCallback(BaseCallback):
    """
    Custom callback for specific training needs.
    """
    
    def __init__(self, verbose=0):
        super().__init__(verbose)
        self.custom_metric = 0
    
    def _on_training_start(self) -> None:
        print("Custom training logic started")
    
    def _on_step(self) -> bool:
        # Custom logic here
        self.custom_metric += 1
        
        # Log custom metrics
        if self.n_calls % 1000 == 0:
            self.logger.record("custom/metric", self.custom_metric)
        
        return True  # Continue training

# Custom callbacks are typically integrated through the hyperparameter
# configuration (the callback entry of the YAML file) or passed directly
# to model.learn(), as sketched below.
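
A minimal sketch of the direct route, reusing the CustomCallback defined above:

from stable_baselines3 import PPO

model = PPO("MlpPolicy", "CartPole-v1", verbose=0)
model.learn(total_timesteps=10000, callback=CustomCallback(verbose=1))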
