A Training Framework for Stable Baselines3 Reinforcement Learning Agents
—
Essential utilities for working with RL environments, models, and hyperparameters. These functions form the foundation of the RL Zoo3 framework, providing core capabilities for algorithm selection, environment creation, model management, and parameter scheduling.
from rl_zoo3 import ALGOS, create_test_env, get_trained_models, linear_schedule
from rl_zoo3.utils import (
get_model_path,
get_saved_hyperparams,
get_latest_run_id,
get_wrapper_class,
get_class_by_name,
flatten_dict_observations,
get_callback_list
)

Central registry mapping algorithm names to their Stable Baselines3 classes, enabling dynamic algorithm selection and instantiation.
ALGOS: dict[str, type[BaseAlgorithm]]

The ALGOS dictionary contains mappings for:
"a2c": A2C (Advantage Actor-Critic)"ddpg": DDPG (Deep Deterministic Policy Gradient)"dqn": DQN (Deep Q-Network)"ppo": PPO (Proximal Policy Optimization)"sac": SAC (Soft Actor-Critic)"td3": TD3 (Twin Delayed Deep Deterministic Policy Gradient)"ars": ARS (Augmented Random Search)"crossq": CrossQ"qrdqn": QRDQN (Quantile Regression DQN)"tqc": TQC (Truncated Quantile Critics)"trpo": TRPO (Trust Region Policy Optimization)"ppo_lstm": RecurrentPPO (PPO with LSTM)Usage example:
from rl_zoo3 import ALGOS
from stable_baselines3.common.env_util import make_vec_env
# Get the PPO algorithm class
ppo_class = ALGOS["ppo"]
# Create environment and model
env = make_vec_env("CartPole-v1", n_envs=1)
model = ppo_class("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000)

Creates vectorized test environments with proper wrappers, normalization, and configuration for evaluation and testing.
def create_test_env(
env_id: str,
n_envs: int = 1,
stats_path: Optional[str] = None,
seed: int = 0,
log_dir: Optional[str] = None,
should_render: bool = True,
hyperparams: Optional[dict[str, Any]] = None,
env_kwargs: Optional[dict[str, Any]] = None,
vec_env_cls: Optional[type[VecEnv]] = None,
vec_env_kwargs: Optional[dict[str, Any]] = None
) -> VecEnv:
"""
Create a wrapped, monitored VecEnv for testing.
Parameters:
- env_id: Environment identifier (e.g., 'CartPole-v1')
- n_envs: Number of parallel environments
- stats_path: Path to VecNormalize statistics file
- seed: Random seed for reproducibility
- log_dir: Directory for logging environment interactions
- should_render: Whether to enable rendering (for PyBullet envs)
- hyperparams: Hyperparameters dict containing env_wrapper settings
- env_kwargs: Additional keyword arguments for environment creation
- vec_env_cls: VecEnv class constructor to use
- vec_env_kwargs: Keyword arguments for VecEnv constructor
Returns:
VecEnv: Configured vectorized environment ready for testing
"""Usage example:
from rl_zoo3 import create_test_env
# Create test environment
env = create_test_env(
env_id="CartPole-v1",
n_envs=4,
seed=42,
should_render=False
)
# Use with trained model
obs = env.reset()
for _ in range(1000):
action = env.action_space.sample() # Random actions for demo
obs, rewards, dones, infos = env.step(action)

Discovers and lists trained models from log directories, returning model paths and metadata for loading and evaluation.
def get_trained_models(log_folder: str) -> dict[str, tuple[str, str]]:
"""
Get a dictionary of trained models from a log folder.
Parameters:
- log_folder: Path to the directory containing trained models
Returns:
dict: Dictionary mapping model names to (model_path, stats_path) tuples
"""def get_hf_trained_models(
organization: str = "sb3",
check_filename: bool = False
) -> dict[str, tuple[str, str]]:
"""
Get trained models from HuggingFace Hub.
Parameters:
- organization: HuggingFace organization name
- check_filename: Whether to validate model filenames
Returns:
dict: Dictionary mapping model names to (repo_id, filename) tuples
"""Usage example:
from rl_zoo3 import get_trained_models, get_hf_trained_models
# Get locally trained models
local_models = get_trained_models("./logs")
print("Local models:", local_models)
# Get models from HuggingFace Hub
hf_models = get_hf_trained_models(organization="sb3")
print("HF models:", list(hf_models.keys())[:5]) # Show first 5Utilities for managing training runs and finding the latest run ID for continued training or evaluation.
def get_latest_run_id(log_path: str, env_name: str) -> int:
"""
Get the latest run ID for a given environment.
Parameters:
- log_path: Path to the log directory
- env_name: Environment name
Returns:
int: Latest run ID (0-indexed)
"""def get_model_path(
exp_id: int,
folder: str,
algo: str,
env_name: str,
load_best: bool = False,
load_checkpoint: Optional[str] = None,
load_last_checkpoint: bool = False
) -> tuple[str, str, str]:
"""
Get the path to a trained model and related information.
Parameters:
- exp_id: Experiment ID (0 for latest)
- folder: Log folder path
- algo: Algorithm name
- env_name: Environment name
- load_best: Whether to load the best model
- load_checkpoint: Specific checkpoint to load (e.g., "100000")
- load_last_checkpoint: Whether to load the last checkpoint
Returns:
tuple[str, str, str]: (name_prefix, model_path, log_path)
"""Loading and managing saved hyperparameters and normalization statistics from trained models.
def get_saved_hyperparams(
stats_path: str,
norm_reward: bool = False,
test_mode: bool = False
) -> tuple[dict[str, Any], str]:
"""
Load saved hyperparameters from a stats file.
Parameters:
- stats_path: Path to the stats.pkl file
- norm_reward: Whether reward normalization was used
- test_mode: Whether in test mode
Returns:
tuple: (hyperparams_dict, stats_path)
"""Usage example:
from rl_zoo3 import get_saved_hyperparams
# Load hyperparameters from trained model
hyperparams, stats_path = get_saved_hyperparams("./logs/ppo/CartPole-v1_1/")
print("Loaded hyperparams:", hyperparams)Linear scheduling functions for hyperparameters like learning rate that need to change during training.
def linear_schedule(initial_value: Union[float, str]) -> SimpleLinearSchedule:
"""
Create a linear schedule for a hyperparameter.
Parameters:
- initial_value: Initial value (float) or string representation
Returns:
SimpleLinearSchedule: Callable schedule object
"""class SimpleLinearSchedule:
"""
Linear parameter scheduling class.
"""
def __init__(self, initial_value: float): ...
def __call__(self, progress_remaining: float) -> float: ...

Usage example:
from rl_zoo3 import linear_schedule, ALGOS
from stable_baselines3.common.env_util import make_vec_env
# Create linear learning rate schedule
lr_schedule = linear_schedule(0.001)
# Use with PPO
env = make_vec_env("CartPole-v1", n_envs=1)
model = ALGOS["ppo"](
"MlpPolicy",
env,
learning_rate=lr_schedule,
verbose=1
)

Utilities for extracting and applying environment wrappers from hyperparameter configurations.
def get_wrapper_class(
hyperparams: dict[str, Any],
key: str = "env_wrapper"
) -> Optional[Callable[[gym.Env], gym.Env]]:
"""
Get one or more Gym environment wrapper class from hyperparams.
Parameters:
- hyperparams: Hyperparameters dictionary
- key: Key in hyperparams containing wrapper specification
Returns:
Optional wrapper class or wrapper chain function
"""def get_class_by_name(name: str) -> type:
"""
Dynamically import a class by its name.
Parameters:
- name: Full class name (e.g., 'stable_baselines3.PPO')
Returns:
type: The imported class
"""def flatten_dict_observations(env: gym.Env) -> gym.Env:
"""
Flatten dictionary observation spaces.
Parameters:
- env: Environment with Dict observation space
Returns:
gym.Env: Environment with flattened observation space
"""Utilities for creating and managing training callbacks from hyperparameter configurations.
def get_callback_list(hyperparams: dict[str, Any]) -> list[BaseCallback]:
"""
Get callback list from hyperparams.
Parameters:
- hyperparams: Hyperparameters dictionary containing callback specifications
Returns:
list[BaseCallback]: List of configured callbacks
"""Usage example:
from rl_zoo3.utils import get_callback_list
# Hyperparams with callbacks
hyperparams = {
"callback": "stable_baselines3.common.callbacks.CheckpointCallback",
"callback_kwargs": {"save_freq": 1000, "save_path": "./checkpoints/"}
}
# Get callback list
callbacks = get_callback_list(hyperparams)
print(f"Created {len(callbacks)} callbacks")class StoreDict(argparse.Action):
"""
Argparse action for storing dictionary parameters.
Converts key=value pairs to dictionary entries.
"""
def __call__(self, parser, namespace, values, option_string=None): ...

Usage example:
import argparse
from rl_zoo3.utils import StoreDict
parser = argparse.ArgumentParser()
parser.add_argument("--env-kwargs", type=str, nargs="+", action=StoreDict)
args = parser.parse_args(["--env-kwargs", "render_mode=human", "max_steps=1000"])
print(args.env_kwargs)  # {'render_mode': 'human', 'max_steps': '1000'}

Install with Tessl CLI
npx tessl i tessl/pypi-rl-zoo3