A Training Framework for Stable Baselines3 Reinforcement Learning Agents
Hyperparameter sampling and optimization utilities using Optuna. Provides algorithm-specific parameter samplers, conversion functions, and distributed optimization support for finding optimal hyperparameters across different RL algorithms.
from rl_zoo3.hyperparams_opt import (
sample_ppo_params,
sample_sac_params,
sample_dqn_params,
sample_td3_params,
sample_a2c_params,
sample_ars_params,
convert_onpolicy_params,
convert_offpolicy_params,
convert_ars_params
)
import optuna
from typing import Any

Functions for converting sampled hyperparameters into the format expected by different algorithm families.
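To make the converters' role concrete before the detailed signatures below, here is a minimal sketch that feeds a converted dictionary straight into a Stable-Baselines3 PPO constructor instead of going through the ExperimentManager. The assumption that every converted key is a valid PPO keyword argument is illustrative and may not hold for every sampled configuration.

import optuna
from stable_baselines3 import PPO
from rl_zoo3.hyperparams_opt import sample_ppo_params, convert_onpolicy_params

def quick_objective(trial):
    # Sample raw values, then convert them for an on-policy algorithm
    sampled = sample_ppo_params(trial, n_actions=2, n_envs=1, additional_args={})
    hyperparams = convert_onpolicy_params(sampled)
    # Assumption: every key in hyperparams is a valid PPO constructor argument
    model = PPO("MlpPolicy", "CartPole-v1", **hyperparams)
    model.learn(total_timesteps=5_000)
    return 0.0  # Replace with an evaluation metric in practice

study = optuna.create_study(direction="maximize")
study.optimize(quick_objective, n_trials=3)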
def convert_onpolicy_params(sampled_params: dict[str, Any]) -> dict[str, Any]:
"""
Convert sampled hyperparameters for on-policy algorithms (PPO, A2C, TRPO).
Parameters:
- sampled_params: Raw hyperparameters from Optuna sampling
Returns:
dict: Converted hyperparameters ready for algorithm use
"""
def convert_offpolicy_params(sampled_params: dict[str, Any]) -> dict[str, Any]:
"""
Convert sampled hyperparameters for off-policy algorithms (SAC, TD3, DQN).
Parameters:
- sampled_params: Raw hyperparameters from Optuna sampling
Returns:
dict: Converted hyperparameters ready for algorithm use
"""
def convert_ars_params(sampled_params: dict[str, Any]) -> dict[str, Any]:
"""
Convert sampled hyperparameters for ARS algorithm.
Parameters:
- sampled_params: Raw hyperparameters from Optuna sampling
Returns:
dict: Converted ARS-specific hyperparameters
"""Sampling functions for Proximal Policy Optimization hyperparameters.
def sample_ppo_params(
trial: optuna.Trial,
n_actions: int,
n_envs: int,
additional_args: dict
) -> dict[str, Any]:
"""
Sample hyperparameters for PPO algorithm.
Parameters:
- trial: Optuna trial object for parameter sampling
- n_actions: Number of actions in the action space
- n_envs: Number of parallel environments
- additional_args: Additional algorithm-specific arguments
Returns:
dict: Sampled PPO hyperparameters including learning_rate, n_steps,
batch_size, n_epochs, gamma, gae_lambda, clip_range, ent_coef, etc.
"""
def sample_ppo_lstm_params(
trial: optuna.Trial,
n_actions: int,
n_envs: int,
additional_args: dict
) -> dict[str, Any]:
"""
Sample hyperparameters for PPO with LSTM policy.
Parameters:
- trial: Optuna trial object
- n_actions: Number of actions
- n_envs: Number of environments
- additional_args: Additional arguments
Returns:
dict: Sampled PPO-LSTM hyperparameters with LSTM-specific parameters
"""Sampling functions for Soft Actor-Critic hyperparameters.
def sample_sac_params(
trial: optuna.Trial,
n_actions: int,
n_envs: int,
additional_args: dict
) -> dict[str, Any]:
"""
Sample hyperparameters for SAC algorithm.
Parameters:
- trial: Optuna trial object for parameter sampling
- n_actions: Number of actions in the action space
- n_envs: Number of parallel environments (typically 1 for SAC)
- additional_args: Additional algorithm-specific arguments
Returns:
dict: Sampled SAC hyperparameters including learning_rate, buffer_size,
batch_size, tau, gamma, train_freq, gradient_steps, ent_coef, etc.
"""Sampling functions for Deep Q-Network and its variants.
def sample_dqn_params(
trial: optuna.Trial,
n_actions: int,
n_envs: int,
additional_args: dict
) -> dict[str, Any]:
"""
Sample hyperparameters for DQN algorithm.
Parameters:
- trial: Optuna trial object
- n_actions: Number of discrete actions
- n_envs: Number of environments
- additional_args: Additional arguments
Returns:
dict: Sampled DQN hyperparameters including learning_rate, buffer_size,
batch_size, tau, gamma, train_freq, target_update_interval, etc.
"""
def sample_qrdqn_params(
trial: optuna.Trial,
n_actions: int,
n_envs: int,
additional_args: dict
) -> dict[str, Any]:
"""
Sample hyperparameters for QR-DQN (Quantile Regression DQN).
Parameters:
- trial: Optuna trial object
- n_actions: Number of actions
- n_envs: Number of environments
- additional_args: Additional arguments
Returns:
dict: Sampled QR-DQN hyperparameters with quantile-specific parameters
"""Sampling functions for Twin Delayed Deep Deterministic Policy Gradient.
def sample_td3_params(
trial: optuna.Trial,
n_actions: int,
n_envs: int,
additional_args: dict
) -> dict[str, Any]:
"""
Sample hyperparameters for TD3 algorithm.
Parameters:
- trial: Optuna trial object
- n_actions: Number of continuous actions
- n_envs: Number of environments
- additional_args: Additional arguments
Returns:
dict: Sampled TD3 hyperparameters including learning_rate, buffer_size,
batch_size, tau, gamma, train_freq, policy_delay, target_policy_noise, etc.
"""Sampling functions for Advantage Actor-Critic.
def sample_a2c_params(
trial: optuna.Trial,
n_actions: int,
n_envs: int,
additional_args: dict
) -> dict[str, Any]:
"""
Sample hyperparameters for A2C algorithm.
Parameters:
- trial: Optuna trial object
- n_actions: Number of actions
- n_envs: Number of parallel environments
- additional_args: Additional arguments
Returns:
dict: Sampled A2C hyperparameters including learning_rate, n_steps,
gamma, gae_lambda, ent_coef, vf_coef, etc.
"""Sampling functions for Trust Region Policy Optimization.
def sample_trpo_params(
trial: optuna.Trial,
n_actions: int,
n_envs: int,
additional_args: dict
) -> dict[str, Any]:
"""
Sample hyperparameters for TRPO algorithm.
Parameters:
- trial: Optuna trial object
- n_actions: Number of actions
- n_envs: Number of environments
- additional_args: Additional arguments
Returns:
dict: Sampled TRPO hyperparameters including learning_rate, n_steps,
batch_size, gamma, gae_lambda, cg_max_steps, target_kl, etc.
"""Sampling functions for Truncated Quantile Critics.
def sample_tqc_params(
trial: optuna.Trial,
n_actions: int,
n_envs: int,
additional_args: dict
) -> dict[str, Any]:
"""
Sample hyperparameters for TQC algorithm.
Parameters:
- trial: Optuna trial object
- n_actions: Number of actions
- n_envs: Number of environments
- additional_args: Additional arguments
Returns:
dict: Sampled TQC hyperparameters with quantile critic parameters
"""Sampling functions for Augmented Random Search.
def sample_ars_params(
trial: optuna.Trial,
n_actions: int,
n_envs: int,
additional_args: dict
) -> dict[str, Any]:
"""
Sample hyperparameters for ARS algorithm.
Parameters:
- trial: Optuna trial object
- n_actions: Number of actions
- n_envs: Number of environments
- additional_args: Additional arguments
Returns:
dict: Sampled ARS hyperparameters including n_delta, n_top, learning_rate,
delta_std, zero_policy, etc.
"""Sampling functions for Hindsight Experience Replay parameters.
def sample_her_params(
trial: optuna.Trial,
hyperparams: dict[str, Any],
her_kwargs: dict[str, Any]
) -> dict[str, Any]:
"""
Sample hyperparameters for HER (Hindsight Experience Replay).
Parameters:
- trial: Optuna trial object
- hyperparams: Base algorithm hyperparameters
- her_kwargs: HER-specific keyword arguments
Returns:
dict: Updated hyperparameters with HER configuration
"""import optuna
from rl_zoo3.hyperparams_opt import sample_ppo_params, convert_onpolicy_params
from rl_zoo3.exp_manager import ExperimentManager
from rl_zoo3 import ALGOS
import argparse
def objective(trial):
# Sample hyperparameters
sampled_params = sample_ppo_params(
trial=trial,
n_actions=2, # CartPole has 2 actions
n_envs=4,
additional_args={}
)
# Convert parameters
hyperparams = convert_onpolicy_params(sampled_params)
# Create experiment manager
args = argparse.Namespace(
algo='ppo',
env='CartPole-v1',
n_timesteps=10000,
eval_freq=2000,
n_eval_episodes=5,
verbose=0
)
exp_manager = ExperimentManager(
args=args,
algo='ppo',
env_id='CartPole-v1',
log_folder='./optim_logs',
hyperparams=hyperparams,
n_timesteps=10000,
eval_freq=2000
)
# Setup and train
model, saved_hyperparams = exp_manager.setup_experiment()
exp_manager.learn(model)
# Return performance metric
# (In practice, this would be extracted from evaluation callback)
return 200.0 # Placeholder reward
# Run optimization
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=20)
print("Best parameters:", study.best_params)
print("Best value:", study.best_value)import optuna
from rl_zoo3.hyperparams_opt import (
sample_ppo_params, sample_sac_params,
convert_onpolicy_params, convert_offpolicy_params
)
def multi_algo_objective(trial):
# Select algorithm
algo_name = trial.suggest_categorical('algorithm', ['ppo', 'sac'])
if algo_name == 'ppo':
sampled_params = sample_ppo_params(trial, n_actions=2, n_envs=4, additional_args={})
hyperparams = convert_onpolicy_params(sampled_params)
elif algo_name == 'sac':
sampled_params = sample_sac_params(trial, n_actions=1, n_envs=1, additional_args={})
hyperparams = convert_offpolicy_params(sampled_params)
# Create and train model with selected algorithm and parameters
# ... (training code similar to above example)
return performance_score
# Optimize across algorithms
study = optuna.create_study(direction='maximize')
study.optimize(multi_algo_objective, n_trials=50)

Example: distributed optimization backed by a shared Optuna storage.

import optuna
from rl_zoo3.exp_manager import ExperimentManager
import argparse
def create_distributed_study():
# Create study with database storage for distributed optimization
study = optuna.create_study(
study_name='rl_zoo3_optimization',
storage='sqlite:///optuna_study.db',
direction='maximize',
load_if_exists=True
)
return study
def distributed_objective(trial):
# Sample parameters for chosen algorithm
algo = 'ppo' # Could be parameterized
if algo == 'ppo':
from rl_zoo3.hyperparams_opt import sample_ppo_params, convert_onpolicy_params
sampled_params = sample_ppo_params(trial, n_actions=4, n_envs=8, additional_args={})
hyperparams = convert_onpolicy_params(sampled_params)
# Create experiment manager with optimization settings
args = argparse.Namespace(
algo=algo,
env='LunarLander-v2',
n_timesteps=50000,
eval_freq=5000,
n_eval_episodes=10,
verbose=0,
seed=trial.suggest_int('seed', 0, 2**32-1)
)
exp_manager = ExperimentManager(
args=args,
algo=algo,
env_id='LunarLander-v2',
log_folder=f'./optim_logs/trial_{trial.number}',
hyperparams=hyperparams,
n_timesteps=50000,
eval_freq=5000
)
# Train and evaluate
model, saved_hyperparams = exp_manager.setup_experiment()
exp_manager.learn(model)
# Extract performance (typically from evaluation callback)
return 0.0 # Placeholder: replace with the evaluation reward
# Run distributed optimization
study = create_distributed_study()
study.optimize(distributed_objective, n_trials=10) # Each process runs 10 trials

Example: defining a custom sampling function with hand-picked parameter ranges.

import optuna
from rl_zoo3.hyperparams_opt import convert_onpolicy_params
def sample_custom_ppo_params(trial, n_actions, n_envs, additional_args):
"""
Custom PPO parameter sampling with different ranges.
"""
# Learning rate with log-uniform distribution
learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True)
# Batch size as power of 2
batch_size_exp = trial.suggest_int('batch_size_exp', 4, 8) # 2^4 to 2^8
batch_size = 2 ** batch_size_exp
# Number of steps
n_steps = trial.suggest_categorical('n_steps', [128, 256, 512, 1024, 2048])
# Ensure batch_size <= n_steps * n_envs
if batch_size > n_steps * n_envs:
batch_size = n_steps * n_envs
# Other hyperparameters
gamma = trial.suggest_float('gamma', 0.9, 0.9999)
gae_lambda = trial.suggest_float('gae_lambda', 0.8, 1.0)
clip_range = trial.suggest_float('clip_range', 0.1, 0.4)
ent_coef = trial.suggest_float('ent_coef', 1e-8, 1e-1, log=True)
return {
'learning_rate': learning_rate,
'n_steps': n_steps,
'batch_size': batch_size,
'gamma': gamma,
'gae_lambda': gae_lambda,
'clip_range': clip_range,
'ent_coef': ent_coef,
'n_epochs': trial.suggest_int('n_epochs', 3, 10),
'vf_coef': trial.suggest_float('vf_coef', 0.1, 1.0)
}
# Use custom sampling in optimization
def custom_objective(trial):
sampled_params = sample_custom_ppo_params(
trial, n_actions=4, n_envs=8, additional_args={}
)
hyperparams = convert_onpolicy_params(sampled_params)
# ... rest of training code
return performance

Example: letting the ExperimentManager drive the full optimization process.

from rl_zoo3.exp_manager import ExperimentManager
import argparse
# ExperimentManager handles optimization automatically
args = argparse.Namespace(
algo='ppo',
env='CartPole-v1',
n_timesteps=20000,
eval_freq=2000,
optimize_hyperparameters=True, # Enable optimization
n_trials=30,
n_jobs=2,
sampler='tpe',
pruner='median',
study_name='ppo_cartpole_study',
storage='sqlite:///ppo_optimization.db'
)
# Automatic hyperparameter optimization
exp_manager = ExperimentManager(
args=args,
algo='ppo',
env_id='CartPole-v1',
log_folder='./optim_logs',
optimize_hyperparameters=True,
n_trials=30,
n_jobs=2,
sampler='tpe',
pruner='median'
)
# This will run the full optimization process
exp_manager.hyperparameters_optimization()

The optimization system supports various Optuna samplers and pruners:
Samplers:
- 'tpe': Tree-structured Parzen Estimator (default, good for most cases)
- 'random': Random sampling (baseline)
- 'cmaes': CMA-ES (good for continuous parameters)

Pruners:
- 'median': Median pruner (default, prunes below median performance)
- 'successive_halving': Successive halving (aggressive pruning)
- 'hyperband': Hyperband (adaptive resource allocation)
- 'nop': No pruning

Install with Tessl CLI
npx tessl i tessl/pypi-rl-zoo3
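For reference, a hedged sketch of how the sampler and pruner names above map onto Optuna objects when a study is assembled by hand rather than through the ExperimentManager; the keyword values are illustrative.

import optuna
from optuna.samplers import TPESampler
from optuna.pruners import MedianPruner

# 'tpe' and 'median' correspond to the defaults described above
study = optuna.create_study(
    direction="maximize",
    sampler=TPESampler(n_startup_trials=5),
    pruner=MedianPruner(n_startup_trials=5, n_warmup_steps=0),
)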