"""A Training Framework for Stable Baselines3 Reinforcement Learning Agents.

Comprehensive experiment orchestration through the ExperimentManager class,
which handles training workflows, hyperparameter optimization, environment
setup, and model coordination. This is the central component that ties
together all aspects of RL training experiments.
"""
from __future__ import annotations

import argparse
from typing import Any, Optional, Union

from rl_zoo3.exp_manager import ExperimentManager

# The main class for managing RL experiments, from initial setup through
# training and evaluation. Handles hyperparameter loading, environment
# creation, model instantiation, and training coordination.
class ExperimentManager:
    """
    Experiment manager: read the hyperparameters,
    preprocess them, create the environment and the RL model.
    """

    def __init__(
        self,
        args: argparse.Namespace,
        algo: str,
        env_id: str,
        log_folder: str,
        tensorboard_log: str = "",
        n_timesteps: int = 0,
        eval_freq: int = 10000,
        n_eval_episodes: int = 5,
        save_freq: int = -1,
        hyperparams: Optional[dict[str, Any]] = None,
        env_kwargs: Optional[dict[str, Any]] = None,
        eval_env_kwargs: Optional[dict[str, Any]] = None,
        trained_agent: str = "",
        optimize_hyperparameters: bool = False,
        storage: Optional[str] = None,
        study_name: Optional[str] = None,
        n_trials: int = 1,
        max_total_trials: Optional[int] = None,
        n_jobs: int = 1,
        sampler: str = "tpe",
        pruner: str = "median",
        optimization_log_path: Optional[str] = None,
        n_startup_trials: int = 0,
        n_evaluations: int = 1,
        truncate_last_trajectory: bool = False,
        uuid_str: str = "",
        seed: int = 0,
        log_interval: int = 0,
        save_replay_buffer: bool = False,
        verbose: int = 1,
        vec_env_type: str = "dummy",
        n_eval_envs: int = 1,
        no_optim_plots: bool = False,
        device: Union[th.device, str] = "auto",
        config: Optional[str] = None,
        show_progress: bool = False,
        trial_id: Optional[int] = None,
    ):
        """
        Initialize ExperimentManager.

        Parameters:
        - args: Command line arguments namespace
        - algo: Algorithm name (must be in ALGOS dict)
        - env_id: Environment identifier
        - log_folder: Directory for saving logs and models
        - tensorboard_log: Tensorboard logging directory
        - n_timesteps: Total training timesteps
        - eval_freq: Frequency of evaluation (in timesteps)
        - n_eval_episodes: Number of episodes for evaluation
        - save_freq: Frequency of model saving (-1 to disable)
        - hyperparams: Override hyperparameters
        - env_kwargs: Environment creation arguments
        - eval_env_kwargs: Evaluation environment arguments
        - trained_agent: Path to pre-trained agent to load
        - optimize_hyperparameters: Whether to run hyperparameter optimization
        - storage: Optuna storage URL for hyperparameter optimization
        - study_name: Optuna study name
        - n_trials: Number of hyperparameter optimization trials
        - max_total_trials: Maximum total trials across all processes
        - n_jobs: Number of parallel jobs for optimization
        - sampler: Optuna sampler ('tpe', 'random', 'cmaes')
        - pruner: Optuna pruner ('median', 'successive_halving', 'hyperband')
        - optimization_log_path: Path for optimization logs
        - n_startup_trials: Number of startup trials for pruner
        - n_evaluations: Number of evaluations per trial
        - truncate_last_trajectory: Whether to truncate last trajectory
        - uuid_str: Unique identifier string
        - seed: Random seed
        - log_interval: Logging interval during training
        - save_replay_buffer: Whether to save replay buffer
        - verbose: Verbosity level
        - vec_env_type: Type of vectorized environment ('dummy', 'subproc')
        - n_eval_envs: Number of parallel evaluation environments
        - no_optim_plots: Whether to disable optimization plots
        - device: Device to use ('auto', 'cpu', 'cuda', torch.device)
        - config: Path to configuration file
        - show_progress: Whether to show progress bar
        - trial_id: Optional trial ID for hyperparameter optimization
        """

    # --- Experiment setup ---------------------------------------------------
    # Core methods for setting up and configuring experiments before training
    # begins.

    def setup_experiment(self) -> BaseAlgorithm:
        """
        Set up the experiment: load hyperparameters, create environments,
        and instantiate the model.

        Returns:
            BaseAlgorithm: Configured RL model ready for training
        """

    def create_log_folder(self) -> None:
        """
        Create log folder and set up logging directories.
        """

    def create_callbacks(self) -> list[BaseCallback]:
        """
        Create training callbacks based on configuration.

        Returns:
            list[BaseCallback]: List of configured callbacks
        """

    # --- Environment management ---------------------------------------------
    # Methods for creating and managing training and evaluation environments
    # with proper configuration and wrappers.

    def create_envs(self, n_envs: int, eval_env: bool = False) -> VecEnv:
        """
        Create vectorized environments for training or evaluation.

        Parameters:
        - n_envs: Number of parallel environments
        - eval_env: Whether this is for evaluation

        Returns:
            VecEnv: Configured vectorized environment
        """

    def get_env_kwargs(self) -> dict[str, Any]:
        """
        Get environment creation keyword arguments.

        Returns:
            dict: Environment kwargs
        """

    # --- Model creation and loading -----------------------------------------
    # Methods for creating new models or loading pre-trained models with
    # proper configuration.

    def create_model(self) -> BaseAlgorithm:
        """
        Create a new RL model with loaded hyperparameters.

        Returns:
            BaseAlgorithm: Configured RL model
        """

    def load_trained_model(self) -> BaseAlgorithm:
        """
        Load a pre-trained model.

        Returns:
            BaseAlgorithm: Loaded RL model
        """

    # --- Training ------------------------------------------------------------
    # Methods for executing the training process with proper monitoring and
    # checkpointing.

    def learn(self, model: BaseAlgorithm) -> None:
        """
        Train the model with configured parameters and callbacks.

        Parameters:
        - model: RL model to train
        """

    def save_trained_model(self, model: BaseAlgorithm) -> None:
        """
        Save the trained model and associated files.

        Parameters:
        - model: Trained RL model to save
        """

    # --- Hyperparameter optimization ------------------------------------------
    # Methods for running hyperparameter optimization using Optuna with
    # distributed training support.

    def hyperparameters_optimization(self) -> None:
        """
        Run hyperparameter optimization using Optuna.

        Supports distributed optimization across multiple processes.
        """

    def objective(self, trial: optuna.Trial) -> float:
        """
        Optuna objective function for hyperparameter optimization.

        Parameters:
        - trial: Optuna trial object

        Returns:
            float: Trial objective value (reward)
        """

    # --- Hyperparameter and configuration I/O --------------------------------
    # Methods for reading, loading, and preprocessing hyperparameters and
    # configuration files.

    def read_hyperparameters(self) -> tuple[dict[str, Any], dict[str, Any]]:
        """
        Read hyperparameters from YAML configuration files.

        Returns:
            tuple[dict[str, Any], dict[str, Any]]: (hyperparams, saved_hyperparams)
        """

    def load_trial(self, trial_id: int) -> None:
        """
        Load a specific Optuna trial configuration.

        Parameters:
        - trial_id: ID of the trial to load
        """

    def _save_config(self, saved_hyperparams: dict[str, Any]) -> None:
        """
        Save configuration and hyperparameters to log directory.

        Parameters:
        - saved_hyperparams: Hyperparameters to save
        """

    # --- Hyperparameter preprocessing (internal) ------------------------------
    # Internal methods for preprocessing hyperparameters and configuration
    # before training.

    @staticmethod
    def _preprocess_schedules(hyperparams: dict[str, Any]) -> dict[str, Any]:
        """
        Preprocess learning rate and other parameter schedules.

        Parameters:
        - hyperparams: Raw hyperparameters

        Returns:
            dict[str, Any]: Processed hyperparameters with schedule objects
        """

    def _preprocess_normalization(self, hyperparams: dict[str, Any]) -> dict[str, Any]:
        """
        Preprocess VecNormalize parameters.

        Parameters:
        - hyperparams: Raw hyperparameters

        Returns:
            dict[str, Any]: Processed hyperparameters with normalization config
        """

    def _preprocess_hyperparams(self, hyperparams: dict[str, Any]) -> dict[str, Any]:
        """
        Preprocess all hyperparameters before model creation.

        Parameters:
        - hyperparams: Raw hyperparameters

        Returns:
            dict[str, Any]: Fully processed hyperparameters
        """

    def _preprocess_action_noise(self, hyperparams: dict[str, Any]) -> dict[str, Any]:
        """
        Preprocess action noise parameters for algorithms that support it.

        Parameters:
        - hyperparams: Raw hyperparameters

        Returns:
            dict[str, Any]: Processed hyperparameters with action noise objects
        """

    # --- Environment/model helpers (internal) ---------------------------------
    # Methods for environment creation, model loading, and related utilities.

    def _maybe_normalize(self, env: VecEnv, eval_env: bool) -> VecEnv:
        """
        Apply VecNormalize wrapper if specified in hyperparameters.

        Parameters:
        - env: Vector environment
        - eval_env: Whether this is an evaluation environment

        Returns:
            VecEnv: Potentially normalized environment
        """

    def _load_pretrained_agent(self, hyperparams: dict[str, Any], env: VecEnv) -> BaseAlgorithm:
        """
        Load a pretrained agent for transfer learning or continued training.

        Parameters:
        - hyperparams: Model hyperparameters
        - env: Training environment

        Returns:
            BaseAlgorithm: Loaded pretrained model
        """

    # --- Optuna factories (internal) ------------------------------------------
    # Methods for creating Optuna samplers and pruners for hyperparameter
    # optimization.

    def _create_sampler(self, sampler_method: str) -> BaseSampler:
        """
        Create Optuna sampler for hyperparameter optimization.

        Parameters:
        - sampler_method: Sampler type ("tpe", "random", "cmaes")

        Returns:
            BaseSampler: Configured Optuna sampler
        """

    def _create_pruner(self, pruner_method: str) -> BasePruner:
        """
        Create Optuna pruner for early stopping of unpromising trials.

        Parameters:
        - pruner_method: Pruner type ("median", "successive_halving", "nop")

        Returns:
            BasePruner: Configured Optuna pruner
        """

    # --- Environment-type detection -------------------------------------------
    # Static methods for detecting specific environment types and applying
    # appropriate configurations.

    @staticmethod
    def entry_point(env_id: str) -> str:
        """
        Get the entry point for a given environment ID.

        Parameters:
        - env_id: Environment identifier

        Returns:
            str: Entry point string
        """

    @staticmethod
    def is_atari(env_id: str) -> bool:
        """
        Check if environment is an Atari environment.

        Parameters:
        - env_id: Environment identifier

        Returns:
            bool: True if Atari environment
        """

    @staticmethod
    def is_minigrid(env_id: str) -> bool:
        """
        Check if environment is a MiniGrid environment.

        Parameters:
        - env_id: Environment identifier

        Returns:
            bool: True if MiniGrid environment
        """

    @staticmethod
    def is_bullet(env_id: str) -> bool:
        """
        Check if environment is a PyBullet environment.

        Parameters:
        - env_id: Environment identifier

        Returns:
            bool: True if PyBullet environment
        """

    @staticmethod
    def is_robotics_env(env_id: str) -> bool:
        """
        Check if environment is a robotics environment.

        Parameters:
        - env_id: Environment identifier

        Returns:
            bool: True if robotics environment
        """

    @staticmethod
    def is_panda_gym(env_id: str) -> bool:
        """
        Check if environment is a Panda Gym environment.

        Parameters:
        - env_id: Environment identifier

        Returns:
            bool: True if Panda Gym environment
        """
from rl_zoo3.exp_manager import ExperimentManager

# Basic example: train PPO on CartPole-v1.
# Build the CLI-style namespace the manager expects
# (this would normally be produced by argparse).
cli_args = argparse.Namespace(
    algo='ppo',
    env='CartPole-v1',
    n_timesteps=10000,
    eval_freq=1000,
    n_eval_episodes=5,
    save_freq=-1,
    verbose=1,
    seed=42,
)

# Instantiate the experiment manager for a simple training run.
exp_manager = ExperimentManager(
    args=cli_args,
    algo='ppo',
    env_id='CartPole-v1',
    log_folder='./logs',
    n_timesteps=10000,
    eval_freq=1000,
    seed=42,
)

# Set everything up, train, then persist the model.
model = exp_manager.setup_experiment()
exp_manager.learn(model)
exp_manager.save_trained_model(model)
from rl_zoo3.exp_manager import ExperimentManager

# Advanced configuration: SAC on Pendulum-v1 with a subprocess vec-env.
# NOTE: the number of parallel environments is configured through ``args``
# (or the hyperparameter file) — ``n_envs`` is NOT a parameter of
# ExperimentManager.__init__, so passing it to the constructor would
# raise a TypeError.
args = argparse.Namespace(
    algo='sac',
    env='Pendulum-v1',
    n_timesteps=50000,
    eval_freq=5000,
    n_eval_episodes=10,
    save_freq=10000,
    verbose=1,
    seed=123,
    tensorboard_log='./tb_logs',
    vec_env_type='subproc',
    n_envs=4,
)

# Custom hyperparameters overriding the tuned defaults.
custom_hyperparams = {
    'learning_rate': 0.0003,
    'buffer_size': 50000,
    'batch_size': 64,
    'tau': 0.02,
    'gamma': 0.98,
}

# Keyword arguments forwarded to environment creation.
env_kwargs = {
    'render_mode': None,
    'max_episode_steps': 200,
}

# Create experiment manager with custom settings.
exp_manager = ExperimentManager(
    args=args,
    algo='sac',
    env_id='Pendulum-v1',
    log_folder='./logs',
    tensorboard_log='./tb_logs',
    n_timesteps=50000,
    eval_freq=5000,
    n_eval_episodes=10,
    save_freq=10000,
    hyperparams=custom_hyperparams,
    env_kwargs=env_kwargs,
    vec_env_type='subproc',
    seed=123,
    show_progress=True,
)

# Setup and train.
model = exp_manager.setup_experiment()
exp_manager.learn(model)
exp_manager.save_trained_model(model)
from rl_zoo3.exp_manager import ExperimentManager

# Hyperparameter-optimization example: PPO on CartPole-v1 with Optuna.
opt_args = argparse.Namespace(
    algo='ppo',
    env='CartPole-v1',
    n_timesteps=10000,
    eval_freq=2000,
    n_eval_episodes=5,
    verbose=0,  # Reduce verbosity for optimization
    seed=42,
)

# Optuna-related settings gathered in one place, unpacked below.
optuna_settings = dict(
    optimize_hyperparameters=True,
    n_trials=50,
    n_jobs=2,
    sampler='tpe',
    pruner='median',
    study_name='ppo_cartpole_optimization',
)

# Create experiment manager for optimization.
exp_manager = ExperimentManager(
    args=opt_args,
    algo='ppo',
    env_id='CartPole-v1',
    log_folder='./optim_logs',
    n_timesteps=10000,
    eval_freq=2000,
    seed=42,
    **optuna_settings,
)

# Run hyperparameter optimization.
exp_manager.hyperparameters_optimization()
from rl_zoo3.exp_manager import ExperimentManager

# Resume-training example: continue from a previously saved agent.
resume_args = argparse.Namespace(
    algo='ppo',
    env='CartPole-v1',
    n_timesteps=20000,  # Additional training steps
    eval_freq=1000,
    verbose=1,
    seed=42,
)

# Pointing ``trained_agent`` at a saved checkpoint makes the manager load
# that model instead of creating a fresh one.
exp_manager = ExperimentManager(
    args=resume_args,
    algo='ppo',
    env_id='CartPole-v1',
    log_folder='./logs',
    trained_agent='./logs/ppo/CartPole-v1_1/best_model.zip',
    n_timesteps=20000,
    eval_freq=1000,
    seed=42,
)

# setup_experiment() loads the trained agent; learn() continues training
# for the additional timesteps before saving.
model = exp_manager.setup_experiment()
exp_manager.learn(model)
exp_manager.save_trained_model(model)
# Install with Tessl CLI
#   npx tessl i tessl/pypi-rl-zoo3