A Training Framework for Stable Baselines3 Reinforcement Learning Agents
npx @tessl/cli install tessl/pypi-rl-zoo3@2.7.0

RL Baselines3 Zoo is a comprehensive training framework for reinforcement learning agents using Stable Baselines3. It provides scripts for training, evaluating, and tuning hyperparameters for RL agents, along with a collection of pre-tuned hyperparameters for common environments and algorithms.
pip install rl_zoo3

import rl_zoo3

Common imports for core functionality:
from rl_zoo3 import ALGOS, create_test_env, get_trained_models
from rl_zoo3.exp_manager import ExperimentManager
from rl_zoo3.utils import get_saved_hyperparams, linear_schedule

from rl_zoo3.train import train
import sys
# Set up command line arguments for training
sys.argv = [
'train.py',
'--algo', 'ppo',
'--env', 'CartPole-v1',
'--n-timesteps', '10000'
]
# Train an agent
train()

import argparse
from rl_zoo3.exp_manager import ExperimentManager
# Create arguments namespace
args = argparse.Namespace(
algo='ppo',
env='CartPole-v1',
n_timesteps=10000,
eval_freq=1000,
n_eval_episodes=5,
save_freq=-1,
verbose=1
)
# Create and setup experiment
exp_manager = ExperimentManager(
args=args,
algo='ppo',
env_id='CartPole-v1',
log_folder='./logs',
n_timesteps=10000
)
# Setup experiment (creates model and environment)
results = exp_manager.setup_experiment()
if results is not None:
    model, saved_hyperparams = results
    # Train the model and save it to the log folder
    exp_manager.learn(model)
    exp_manager.save_trained_model(model)

from rl_zoo3 import get_trained_models, create_test_env
from rl_zoo3.enjoy import enjoy
import sys
# Get available trained models
trained_models = get_trained_models('./logs')
print("Available models:", trained_models)
# Set up command line arguments for evaluation
sys.argv = [
'enjoy.py',
'--algo', 'ppo',
'--env', 'CartPole-v1',
'--folder', './logs',
'--n-timesteps', '1000'
]
# Evaluate the trained agent
enjoy()

RL Zoo3 is built around several key components:
The framework integrates with Optuna for hyperparameter optimization and with the HuggingFace Hub for model sharing, and supports multiple environment libraries including Gymnasium (formerly OpenAI Gym), Atari, MuJoCo, and PyBullet.
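For instance, the Optuna integration can be driven through the training entry point. The sketch below assumes the optimization flags exposed by train.py (-optimize, --n-trials, --sampler, --pruner) and uses a deliberately small search:

import sys

from rl_zoo3.train import train

# Run an Optuna hyperparameter search instead of a single training run
sys.argv = [
    'train.py',
    '--algo', 'ppo',
    '--env', 'CartPole-v1',
    '--n-timesteps', '10000',
    '-optimize',
    '--n-trials', '10',
    '--sampler', 'tpe',
    '--pruner', 'median',
]
train()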
Essential utilities for working with RL environments, models, and hyperparameters. Includes algorithm mapping, environment creation, model loading, hyperparameter management, and scheduling functions.
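A minimal sketch of how these utilities fit together when reloading a trained run for evaluation; the stats path is hypothetical and the full signatures follow below:

from rl_zoo3 import ALGOS, create_test_env
from rl_zoo3.utils import get_saved_hyperparams, linear_schedule

# Look up an algorithm class by its identifier
ppo_cls = ALGOS["ppo"]

# Load hyperparameters and normalization statistics saved with a run
# ('./logs/ppo/CartPole-v1_1/CartPole-v1' is a hypothetical stats path)
hyperparams, maybe_stats_path = get_saved_hyperparams("./logs/ppo/CartPole-v1_1/CartPole-v1")

# Recreate a matching vectorized evaluation environment
env = create_test_env(
    "CartPole-v1",
    n_envs=1,
    stats_path=maybe_stats_path,
    seed=0,
    should_render=False,
    hyperparams=hyperparams,
)

# Linearly decaying schedule, e.g. for a learning rate starting at 3e-4
lr_schedule = linear_schedule(3e-4)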
ALGOS: dict[str, type[BaseAlgorithm]]
def create_test_env(
env_id: str,
n_envs: int = 1,
stats_path: Optional[str] = None,
seed: int = 0,
log_dir: Optional[str] = None,
should_render: bool = True,
hyperparams: Optional[dict[str, Any]] = None,
env_kwargs: Optional[dict[str, Any]] = None,
vec_env_cls: Optional[type[VecEnv]] = None,
vec_env_kwargs: Optional[dict[str, Any]] = None
) -> VecEnv: ...
def get_trained_models(log_folder: str) -> dict[str, tuple[str, str]]: ...
def get_saved_hyperparams(
stats_path: str,
norm_reward: bool = False,
test_mode: bool = False
) -> tuple[dict[str, Any], str]: ...
def linear_schedule(initial_value: Union[float, str]) -> SimpleLinearSchedule: ...

Comprehensive experiment orchestration through the ExperimentManager class, handling training workflows, hyperparameter optimization, environment setup, and model coordination.
class ExperimentManager:
def __init__(
self,
args: argparse.Namespace,
algo: str,
env_id: str,
log_folder: str,
tensorboard_log: str = "",
n_timesteps: int = 0,
eval_freq: int = 10000,
n_eval_episodes: int = 5,
save_freq: int = -1,
hyperparams: Optional[dict[str, Any]] = None,
env_kwargs: Optional[dict[str, Any]] = None,
**kwargs
): ...
def setup_experiment(self) -> Optional[tuple[BaseAlgorithm, dict[str, Any]]]: ...  # (model, saved_hyperparams), or None when only optimizing hyperparameters
def learn(self, model: BaseAlgorithm) -> None: ...
def save_trained_model(self, model: BaseAlgorithm) -> None: ...

Custom callbacks for training monitoring, evaluation, hyperparameter optimization, and logging. Includes specialized callbacks for Optuna trials, VecNormalize saving, and parallel training.
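As an illustration, SaveVecNormalizeCallback can periodically snapshot VecNormalize statistics during a plain Stable Baselines3 training run; the environment id, timestep budget, and save path below are illustrative:

import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

from rl_zoo3.callbacks import SaveVecNormalizeCallback

# Wrap the environment with VecNormalize so the callback has statistics to save
env = VecNormalize(DummyVecEnv([lambda: gym.make("Pendulum-v1")]))
model = PPO("MlpPolicy", env, verbose=0)

# Save the normalization statistics every 1000 steps to a hypothetical folder
callback = SaveVecNormalizeCallback(save_freq=1000, save_path="./logs/ppo_pendulum")
model.learn(total_timesteps=5000, callback=callback)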
class TrialEvalCallback(EvalCallback):
def __init__(
self,
eval_env: VecEnv,
trial: optuna.Trial,
n_eval_episodes: int = 5,
eval_freq: int = 10000,
**kwargs
): ...
class SaveVecNormalizeCallback(BaseCallback):
def __init__(self, save_freq: int, save_path: str, **kwargs): ...
class ParallelTrainCallback(BaseCallback): ...
class RawStatisticsCallback(BaseCallback): ...

Custom Gymnasium environment wrappers for observation processing, reward modification, action manipulation, and training optimization. Includes wrappers for success truncation, action noise, history tracking, and frame skipping.
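The wrappers follow the standard Gymnasium wrapper pattern and can be stacked. A small sketch using ActionNoiseWrapper and HistoryWrapper (the environment id and parameters are illustrative):

import gymnasium as gym

from rl_zoo3.wrappers import ActionNoiseWrapper, HistoryWrapper

# Add Gaussian noise to continuous actions, then augment observations with
# a short history of past observations and actions
env = gym.make("Pendulum-v1")
env = ActionNoiseWrapper(env, noise_std=0.1)
env = HistoryWrapper(env, horizon=2)

obs, info = env.reset()
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())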
class TruncatedOnSuccessWrapper(gym.Wrapper):
def __init__(self, env: gym.Env, reward_offset: float = 0.0, n_successes: int = 1): ...
class ActionNoiseWrapper(gym.Wrapper):
def __init__(self, env: gym.Env, noise_std: float = 0.1): ...
class HistoryWrapper(gym.Wrapper):
def __init__(self, env: gym.Env, horizon: int = 2): ...
class DelayedRewardWrapper(gym.Wrapper):
def __init__(self, env: gym.Env, delay: int = 10): ...

Hyperparameter sampling and optimization utilities using Optuna. Includes algorithm-specific parameter samplers and conversion functions for different RL algorithms.
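These samplers plug directly into an Optuna objective. A minimal sketch, assuming a discrete-action, single-environment setup and an empty additional_args dict (the actual training step is elided):

import optuna

from rl_zoo3.hyperparams_opt import sample_ppo_params

def objective(trial: optuna.Trial) -> float:
    # Draw one PPO hyperparameter configuration for this trial
    params = sample_ppo_params(trial, n_actions=2, n_envs=1, additional_args={})
    # ... train a model with `params` and return its mean evaluation reward ...
    return 0.0

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=5)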
def sample_ppo_params(trial: optuna.Trial, n_actions: int, n_envs: int, additional_args: dict) -> dict[str, Any]: ...
def sample_sac_params(trial: optuna.Trial, n_actions: int, n_envs: int, additional_args: dict) -> dict[str, Any]: ...
def sample_dqn_params(trial: optuna.Trial, n_actions: int, n_envs: int, additional_args: dict) -> dict[str, Any]: ...
def convert_onpolicy_params(sampled_params: dict[str, Any]) -> dict[str, Any]: ...
def convert_offpolicy_params(sampled_params: dict[str, Any]) -> dict[str, Any]: ...

Comprehensive plotting tools for training curves, evaluation results, and performance analysis. Includes functions for plotting from log files, training progress, and generating publication-quality plots.
def plot_train(): ...
def plot_from_file(): ...
def all_plots(): ...
def normalize_score(score: np.ndarray, env_id: str) -> np.ndarray: ...

Model sharing and loading through HuggingFace Hub integration. Includes functions for uploading trained models, downloading pre-trained models, and generating model cards.
def package_to_hub(
model: BaseAlgorithm,
model_name: str,
repo_id: str,
commit_message: str = "Add model",
**kwargs
) -> str: ...
def download_from_hub(
repo_id: str,
filename: str,
**kwargs
) -> str: ...
def generate_model_card(
model: BaseAlgorithm,
env_id: str,
**kwargs
) -> str: ...

RL Zoo3 provides several command-line entry points:
# Main CLI entry point
rl_zoo3 train --algo ppo --env CartPole-v1
rl_zoo3 enjoy --algo ppo --env CartPole-v1 --folder logs/
rl_zoo3 plot_train --log-dir logs/
rl_zoo3 plot_from_file --log-dir logs/
rl_zoo3 all_plots --log-dir logs/
# Direct script execution
python -m rl_zoo3.train --algo ppo --env CartPole-v1
python -m rl_zoo3.enjoy --algo ppo --env CartPole-v1 --folder logs/

RL Zoo3 supports the reinforcement learning algorithms shipped with Stable Baselines3 (A2C, DDPG, DQN, PPO, SAC, TD3) as well as the additional algorithms from SB3-Contrib (such as ARS, QR-DQN, TQC, TRPO, and Recurrent PPO), all exposed through the ALGOS mapping.
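To check exactly which algorithm identifiers are available in a given installation, inspect the ALGOS mapping directly:

from rl_zoo3 import ALGOS

# Algorithm identifiers accepted by --algo; the exact set depends on the
# installed stable-baselines3 and sb3-contrib versions
print(sorted(ALGOS.keys()))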
import argparse
from typing import Any, Callable, Optional, Union
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.vec_env import VecEnv
import gymnasium as gym
import optuna
import numpy as np
import torch as th
SimpleLinearSchedule = type # Linear parameter scheduling class
StoreDict = type # Argparse action for storing dict parameters
# Common type aliases
EnvironmentName = str
ModelName = str
HyperparamDict = dict[str, Any]
CallbackList = list[BaseCallback]
WrapperClass = Callable[[gym.Env], gym.Env]