A Training Framework for Stable Baselines3 Reinforcement Learning Agents
Custom callbacks for training monitoring, evaluation, hyperparameter optimization, and logging. These callbacks extend Stable Baselines3's callback system with specialized functionality for RL Zoo3's training workflows.
from rl_zoo3.callbacks import (
TrialEvalCallback,
SaveVecNormalizeCallback,
ParallelTrainCallback,
RawStatisticsCallback
)
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
from stable_baselines3.common.vec_env import VecEnv
from typing import Optional
import optuna

Specialized callback for Optuna hyperparameter optimization trials, providing evaluation and pruning functionality during optimization runs.
class TrialEvalCallback(EvalCallback):
"""
Callback used for evaluating and reporting a trial during hyperparameter optimization.
Extends EvalCallback with Optuna trial integration.
"""
def __init__(
self,
eval_env: VecEnv,
trial: optuna.Trial,
n_eval_episodes: int = 5,
eval_freq: int = 10000,
deterministic: bool = True,
verbose: int = 0,
best_model_save_path: Optional[str] = None,
log_path: Optional[str] = None
) -> None:
"""
Initialize TrialEvalCallback.
Parameters:
- eval_env: Vectorized evaluation environment
- trial: Optuna trial object for reporting results
- n_eval_episodes: Number of episodes for each evaluation
- eval_freq: Frequency of evaluation (in timesteps)
- deterministic: Whether to use deterministic actions during evaluation
- verbose: Verbosity level
- best_model_save_path: Path to save best model
- log_path: Path for evaluation logs
"""
def _on_step(self) -> bool:
"""
Called at each training step. Performs evaluation and reports to Optuna.
Returns:
bool: Whether training should continue
"""Usage example:
import optuna
from rl_zoo3.callbacks import TrialEvalCallback
from rl_zoo3 import create_test_env
from stable_baselines3 import PPO

def objective(trial):
    # Create evaluation environment
    eval_env = create_test_env("CartPole-v1", n_envs=1)
    # Create callback that evaluates the agent and reports results to the Optuna trial
    eval_callback = TrialEvalCallback(
        eval_env=eval_env,
        trial=trial,
        n_eval_episodes=10,
        eval_freq=1000,
        deterministic=True,
        verbose=0,
    )
    # Train the model with the evaluation callback
    model = PPO("MlpPolicy", "CartPole-v1", verbose=0)
    model.learn(total_timesteps=10000, callback=eval_callback)
    # Propagate pruning decisions made by the callback
    if eval_callback.is_pruned:
        raise optuna.exceptions.TrialPruned()
    return eval_callback.last_mean_reward

# Run optimization
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=10)

Callback for automatically saving VecNormalize statistics during training, ensuring normalization parameters are preserved with the model.
class SaveVecNormalizeCallback(BaseCallback):
"""
Callback for saving VecNormalize statistics at regular intervals.
"""
def __init__(
self,
save_freq: int,
save_path: str,
name_prefix: str = "vecnormalize",
verbose: int = 0
):
"""
Initialize SaveVecNormalizeCallback.
Parameters:
- save_freq: Frequency of saving (in timesteps)
- save_path: Directory path for saving statistics
- name_prefix: Prefix for saved files
- verbose: Verbosity level
"""
def _on_step(self) -> bool:
"""
Called at each training step. Saves VecNormalize stats at specified frequency.
Returns:
bool: Always True (never stops training)
"""Usage example:
from rl_zoo3.callbacks import SaveVecNormalizeCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize
from stable_baselines3 import PPO

# Create normalized environment
base_env = make_vec_env("CartPole-v1", n_envs=1)
env = VecNormalize(base_env, norm_obs=True, norm_reward=True)
# Create callback
save_callback = SaveVecNormalizeCallback(
    save_freq=5000,
    save_path="./logs/",
    name_prefix="ppo_vecnormalize",
    verbose=1,
)
# Train with callback
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=20000, callback=save_callback)
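To reuse the saved statistics at evaluation time, load them into a fresh environment. A minimal sketch, assuming the checkpoint name is derived from name_prefix and the current timestep (the exact file name on disk may differ):

from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize

# Load the saved normalization statistics into a new environment
eval_env = make_vec_env("CartPole-v1", n_envs=1)
eval_env = VecNormalize.load("./logs/ppo_vecnormalize_20000_steps.pkl", eval_env)
# Freeze statistics and disable reward normalization during evaluation
eval_env.training = False
eval_env.norm_reward = False

Callback for coordinating parallel training processes, handling synchronization and communication between multiple training instances.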
class ParallelTrainCallback(BaseCallback):
"""
Callback for parallel training coordination.
Handles synchronization between multiple training processes.
"""
def __init__(self, verbose: int = 0):
"""
Initialize ParallelTrainCallback.
Parameters:
- verbose: Verbosity level
"""
def _on_training_start(self) -> None:
"""
Called when training starts. Sets up parallel coordination.
"""
def _on_step(self) -> bool:
"""
Called at each training step. Handles parallel synchronization.
Returns:
bool: Whether training should continue
"""
def _on_training_end(self) -> None:
"""
Called when training ends. Cleans up parallel resources.
"""Callback for logging detailed training statistics and metrics, providing comprehensive monitoring of training progress.
class RawStatisticsCallback(BaseCallback):
"""
Callback for logging raw training statistics.
Provides detailed metrics beyond standard Stable Baselines3 logging.
"""
def __init__(
self,
verbose: int = 0,
log_freq: int = 1000
):
"""
Initialize RawStatisticsCallback.
Parameters:
- verbose: Verbosity level
- log_freq: Frequency of detailed logging (in timesteps)
"""
def _on_step(self) -> bool:
"""
Called at each training step. Logs detailed statistics at specified frequency.
Returns:
bool: Always True (never stops training)
"""
def _on_rollout_end(self) -> None:
"""
Called at the end of each rollout. Logs rollout statistics.
"""Usage example:
from rl_zoo3.callbacks import RawStatisticsCallback
from stable_baselines3 import PPO

# Create callback
stats_callback = RawStatisticsCallback(
    verbose=1,
    log_freq=1000,
)
# Train with detailed statistics logging
model = PPO("MlpPolicy", "CartPole-v1", verbose=1)
model.learn(total_timesteps=50000, callback=stats_callback)

You can combine multiple callbacks for comprehensive training monitoring:
from rl_zoo3.callbacks import SaveVecNormalizeCallback, RawStatisticsCallback
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback, CallbackList
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3 import PPO

# Create training and evaluation environments
env = make_vec_env("CartPole-v1", n_envs=1)
eval_env = make_vec_env("CartPole-v1", n_envs=1)
# Create multiple callbacks
checkpoint_callback = CheckpointCallback(
    save_freq=10000,
    save_path="./checkpoints/",
    name_prefix="ppo_model",
)
eval_callback = EvalCallback(
    eval_env=eval_env,
    best_model_save_path="./best_models/",
    log_path="./eval_logs/",
    eval_freq=5000,
    n_eval_episodes=10,
    deterministic=True,
)
save_vec_callback = SaveVecNormalizeCallback(
    save_freq=5000,
    save_path="./vecnormalize/",
    verbose=1,
)
stats_callback = RawStatisticsCallback(
    verbose=1,
    log_freq=1000,
)
# Combine callbacks
callback_list = CallbackList([
    checkpoint_callback,
    eval_callback,
    save_vec_callback,
    stats_callback,
])
# Train with all callbacks
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=100000, callback=callback_list)

The ExperimentManager automatically creates and configures appropriate callbacks based on the training setup:
from rl_zoo3.exp_manager import ExperimentManager
import argparse
# ExperimentManager will automatically create callbacks based on configuration
args = argparse.Namespace(
algo='ppo',
env='CartPole-v1',
n_timesteps=50000,
eval_freq=5000, # Will create EvalCallback
save_freq=10000, # Will create CheckpointCallback
verbose=1
)
exp_manager = ExperimentManager(
args=args,
algo='ppo',
env_id='CartPole-v1',
log_folder='./logs',
n_timesteps=50000,
eval_freq=5000,
save_freq=10000
)
# Callbacks are automatically created and configured
model, saved_hyperparams = exp_manager.setup_experiment()
exp_manager.learn(model)  # Uses automatically created callbacks

You can also create custom callbacks that work with RL Zoo3's training system:
from stable_baselines3.common.callbacks import BaseCallback
class CustomCallback(BaseCallback):
"""
Custom callback for specific training needs.
"""
def __init__(self, verbose=0):
super().__init__(verbose)
self.custom_metric = 0
def _on_training_start(self) -> None:
print("Custom training logic started")
def _on_step(self) -> bool:
# Custom logic here
self.custom_metric += 1
# Log custom metrics
if self.n_calls % 1000 == 0:
self.logger.record("custom/metric", self.custom_metric)
return True # Continue training
# Use custom callback with ExperimentManager
# You would typically integrate this through hyperparams configuration
# or by modifying the ExperimentManager's create_callbacks method
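To try the callback outside the ExperimentManager, you can pass it to model.learn directly. A minimal sketch using the CustomCallback defined above:

from stable_baselines3 import PPO

model = PPO("MlpPolicy", "CartPole-v1", verbose=1)
model.learn(total_timesteps=10000, callback=CustomCallback(verbose=1))

Install with Tessl CLI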
npx tessl i tessl/pypi-rl-zoo3