tessl/pypi-rl-zoo3

A Training Framework for Stable Baselines3 Reinforcement Learning Agents

Workspace: tessl
Visibility: Public
Describes: pkg:pypi/rl-zoo3@2.7.x

To install, run

npx @tessl/cli install tessl/pypi-rl-zoo3@2.7.0


RL Zoo3

RL Baselines3 Zoo is a comprehensive training framework for reinforcement learning agents using Stable Baselines3. It provides scripts for training, evaluating, and tuning hyperparameters for RL agents, along with a collection of pre-tuned hyperparameters for common environments and algorithms.

Package Information

  • Package Name: rl-zoo3
  • Language: Python
  • Installation: pip install rl_zoo3

Core Imports

import rl_zoo3

Common imports for core functionality:

from rl_zoo3 import ALGOS, create_test_env, get_trained_models
from rl_zoo3.exp_manager import ExperimentManager
from rl_zoo3.utils import get_saved_hyperparams, linear_schedule
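
For example, ALGOS maps lowercase algorithm names to their Stable Baselines3 classes, so a model class can be looked up by name (a minimal sketch; the exact set of registered keys depends on which sb3-contrib algorithms are installed):

from rl_zoo3 import ALGOS

# Look up the PPO class by its registry key and train it directly
ppo_cls = ALGOS["ppo"]
model = ppo_cls("MlpPolicy", "CartPole-v1", verbose=0)
model.learn(total_timesteps=1_000)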

Basic Usage

Quick Training Example

from rl_zoo3.train import train
import sys

# Set up command line arguments for training
sys.argv = [
    'train.py',
    '--algo', 'ppo',
    '--env', 'CartPole-v1',
    '--n-timesteps', '10000'
]

# Train an agent
train()

Using ExperimentManager

import argparse
from rl_zoo3.exp_manager import ExperimentManager

# Create arguments namespace
args = argparse.Namespace(
    algo='ppo',
    env='CartPole-v1',
    n_timesteps=10000,
    eval_freq=1000,
    n_eval_episodes=5,
    save_freq=-1,
    verbose=1
)

# Create and setup experiment
exp_manager = ExperimentManager(
    args=args,
    algo='ppo',
    env_id='CartPole-v1',
    log_folder='./logs',
    n_timesteps=10000
)

# Setup experiment (creates the environment and the model).
# setup_experiment() returns a (model, saved_hyperparams) tuple,
# or None when only hyperparameter optimization is requested.
model, saved_hyperparams = exp_manager.setup_experiment()

# Train the model
model.learn(total_timesteps=10000)

Loading and Evaluating Trained Models

from rl_zoo3 import get_trained_models, create_test_env
from rl_zoo3.enjoy import enjoy
import sys

# Get available trained models
trained_models = get_trained_models('./logs')
print("Available models:", trained_models)

# Set up command line arguments for evaluation
sys.argv = [
    'enjoy.py',
    '--algo', 'ppo',
    '--env', 'CartPole-v1',
    '--folder', './logs',
    '--n-timesteps', '1000'
]

# Evaluate the trained agent
enjoy()
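
A trained agent can also be evaluated programmatically, without going through the enjoy entry point, by combining create_test_env, get_saved_hyperparams, and Stable Baselines3's evaluate_policy. The sketch below assumes a checkpoint layout like the one produced by training; the paths are hypothetical:

from rl_zoo3 import ALGOS, create_test_env
from rl_zoo3.utils import get_saved_hyperparams
from stable_baselines3.common.evaluation import evaluate_policy

# Hypothetical paths from a previous training run
stats_path = "./logs/ppo/CartPole-v1_1/CartPole-v1"
model_path = "./logs/ppo/CartPole-v1_1/CartPole-v1.zip"

# Restore saved hyperparameters and normalization statistics, if any
hyperparams, maybe_stats_path = get_saved_hyperparams(stats_path)

# Build the evaluation environment and load the trained agent
env = create_test_env(
    "CartPole-v1",
    n_envs=1,
    stats_path=maybe_stats_path,
    should_render=False,
    hyperparams=hyperparams,
)
model = ALGOS["ppo"].load(model_path, env=env)

mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
print(f"Mean reward: {mean_reward:.1f} +/- {std_reward:.1f}")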

Architecture

RL Zoo3 is built around several key components:

  • ExperimentManager: Central orchestrator for training experiments, handles model creation, environment setup, hyperparameter loading, and training coordination
  • Algorithm Dictionary (ALGOS): Maps algorithm names to their Stable Baselines3 classes, supporting A2C, PPO, SAC, TD3, DQN, and more
  • Utilities: Core functions for environment creation, model loading, hyperparameter management, and file operations
  • Callbacks: Custom training callbacks for evaluation, hyperparameter optimization, and logging
  • Wrappers: Environment wrappers for observation processing, reward modification, and training optimization
  • Plotting: Visualization tools for training curves, evaluation results, and performance analysis

The framework integrates with Optuna for hyperparameter optimization and the HuggingFace Hub for model sharing, and it supports multiple environment suites, including Gymnasium (the successor to OpenAI Gym), Atari, MuJoCo, and PyBullet.

Capabilities

Core Utilities

Essential utilities for working with RL environments, models, and hyperparameters. Includes algorithm mapping, environment creation, model loading, hyperparameter management, and scheduling functions.

ALGOS: dict[str, type[BaseAlgorithm]]

def create_test_env(
    env_id: str,
    n_envs: int = 1,
    stats_path: Optional[str] = None,
    seed: int = 0,
    log_dir: Optional[str] = None,
    should_render: bool = True,
    hyperparams: Optional[dict[str, Any]] = None,
    env_kwargs: Optional[dict[str, Any]] = None,
    vec_env_cls: Optional[type[VecEnv]] = None,
    vec_env_kwargs: Optional[dict[str, Any]] = None
) -> VecEnv: ...

def get_trained_models(log_folder: str) -> dict[str, tuple[str, str]]: ...

def get_saved_hyperparams(
    stats_path: str,
    norm_reward: bool = False,
    test_mode: bool = False
) -> tuple[dict[str, Any], str]: ...

def linear_schedule(initial_value: Union[float, str]) -> SimpleLinearSchedule: ...
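
linear_schedule is typically used to decay a hyperparameter (most often the learning rate) linearly from an initial value to zero over the course of training. A minimal sketch of passing the returned schedule to a Stable Baselines3 model:

from rl_zoo3.utils import linear_schedule
from stable_baselines3 import PPO

# Learning rate decays linearly from 2.5e-4 to 0 as training progresses
model = PPO("MlpPolicy", "CartPole-v1", learning_rate=linear_schedule(2.5e-4))
model.learn(total_timesteps=1_000)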

Experiment Management

Comprehensive experiment orchestration through the ExperimentManager class, handling training workflows, hyperparameter optimization, environment setup, and model coordination.

class ExperimentManager:
    def __init__(
        self,
        args: argparse.Namespace,
        algo: str,
        env_id: str,
        log_folder: str,
        tensorboard_log: str = "",
        n_timesteps: int = 0,
        eval_freq: int = 10000,
        n_eval_episodes: int = 5,
        save_freq: int = -1,
        hyperparams: Optional[dict[str, Any]] = None,
        env_kwargs: Optional[dict[str, Any]] = None,
        **kwargs
    ): ...
    
    def setup_experiment(self) -> Optional[tuple[BaseAlgorithm, dict[str, Any]]]: ...
    def learn(self, model: BaseAlgorithm) -> None: ...
    def save_trained_model(self, model: BaseAlgorithm) -> None: ...
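
The typical workflow, mirroring what rl_zoo3.train does internally, is setup_experiment followed by learn and save_trained_model. This is a sketch: the hand-built argparse.Namespace stands in for the fully parsed CLI arguments and is an assumption.

import argparse
from rl_zoo3.exp_manager import ExperimentManager

# Hand-built namespace standing in for the parsed train.py arguments (an assumption)
args = argparse.Namespace(algo="ppo", env="CartPole-v1", n_timesteps=5000, verbose=1, seed=0)

exp_manager = ExperimentManager(
    args=args,
    algo="ppo",
    env_id="CartPole-v1",
    log_folder="./logs",
    n_timesteps=5000,
)

# setup_experiment() returns None when only hyperparameter optimization
# is requested, otherwise a (model, saved_hyperparams) pair
results = exp_manager.setup_experiment()
if results is not None:
    model, saved_hyperparams = results
    exp_manager.learn(model)               # train with the configured callbacks
    exp_manager.save_trained_model(model)  # write the model (and VecNormalize stats) to the log folder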

Training Callbacks

Custom callbacks for training monitoring, evaluation, hyperparameter optimization, and logging. Includes specialized callbacks for Optuna trials, VecNormalize saving, and parallel training.

class TrialEvalCallback(EvalCallback):
    def __init__(
        self,
        eval_env: VecEnv,
        trial: optuna.Trial,
        n_eval_episodes: int = 5,
        eval_freq: int = 10000,
        **kwargs
    ): ...

class SaveVecNormalizeCallback(BaseCallback):
    def __init__(self, save_freq: int, save_path: str, **kwargs): ...

class ParallelTrainCallback(BaseCallback): ...
class RawStatisticsCallback(BaseCallback): ...
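
For example, SaveVecNormalizeCallback is usually attached to an EvalCallback so that the VecNormalize statistics are saved alongside the best model, which is how the ExperimentManager uses it internally. A sketch assuming a VecNormalize-wrapped training setup:

from rl_zoo3.callbacks import SaveVecNormalizeCallback
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize

# Normalized training and evaluation environments
train_env = VecNormalize(make_vec_env("CartPole-v1", n_envs=4))
eval_env = VecNormalize(make_vec_env("CartPole-v1", n_envs=1), training=False, norm_reward=False)

# Save vecnormalize.pkl whenever a new best model is found
save_vec_normalize = SaveVecNormalizeCallback(save_freq=1, save_path="./logs/best_model")
eval_callback = EvalCallback(
    eval_env,
    best_model_save_path="./logs/best_model",
    callback_on_new_best=save_vec_normalize,
    eval_freq=500,
)

model = PPO("MlpPolicy", train_env, verbose=0)
model.learn(total_timesteps=2_000, callback=eval_callback)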

Environment Wrappers

Custom Gymnasium environment wrappers for observation processing, reward modification, action manipulation, and training optimization. Includes wrappers for success truncation, action noise, history tracking, and frame skipping.

class TruncatedOnSuccessWrapper(gym.Wrapper):
    def __init__(self, env: gym.Env, reward_offset: float = 0.0, n_successes: int = 1): ...

class ActionNoiseWrapper(gym.Wrapper):
    def __init__(self, env: gym.Env, noise_std: float = 0.1): ...

class HistoryWrapper(gym.Wrapper):
    def __init__(self, env: gym.Env, horizon: int = 2): ...

class DelayedRewardWrapper(gym.Wrapper):
    def __init__(self, env: gym.Env, delay: int = 10): ...
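
The wrappers follow the standard Gymnasium wrapper interface and can be stacked on top of each other. A minimal sketch using a continuous-control environment, since ActionNoiseWrapper adds Gaussian noise to continuous actions:

import gymnasium as gym
from rl_zoo3.wrappers import ActionNoiseWrapper, HistoryWrapper

# Add Gaussian action noise, then stack the last two observations/actions
env = gym.make("Pendulum-v1")
env = ActionNoiseWrapper(env, noise_std=0.1)
env = HistoryWrapper(env, horizon=2)

obs, info = env.reset()
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())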

Hyperparameter Optimization

Hyperparameter sampling and optimization utilities using Optuna. Includes algorithm-specific parameter samplers and conversion functions for different RL algorithms.

def sample_ppo_params(trial: optuna.Trial, n_actions: int, n_envs: int, additional_args: dict) -> dict[str, Any]: ...
def sample_sac_params(trial: optuna.Trial, n_actions: int, n_envs: int, additional_args: dict) -> dict[str, Any]: ...
def sample_dqn_params(trial: optuna.Trial, n_actions: int, n_envs: int, additional_args: dict) -> dict[str, Any]: ...
def convert_onpolicy_params(sampled_params: dict[str, Any]) -> dict[str, Any]: ...
def convert_offpolicy_params(sampled_params: dict[str, Any]) -> dict[str, Any]: ...
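
These samplers can be dropped into a custom Optuna objective. The sketch below follows the signatures listed above and uses a plain post-training evaluation as the objective, whereas the zoo itself prunes trials with TrialEvalCallback during training:

import optuna
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from rl_zoo3.hyperparams_opt import sample_ppo_params

def objective(trial: optuna.Trial) -> float:
    # Draw PPO hyperparameters for a single CartPole environment
    kwargs = sample_ppo_params(trial, n_actions=1, n_envs=1, additional_args={})
    model = PPO("MlpPolicy", "CartPole-v1", **kwargs)
    model.learn(total_timesteps=5_000)
    mean_reward, _ = evaluate_policy(model, model.get_env(), n_eval_episodes=10)
    return mean_reward

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=5)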

Plotting and Visualization

Comprehensive plotting tools for training curves, evaluation results, and performance analysis. Includes functions for plotting from log files, training progress, and generating publication-quality plots.

def plot_train(): ...
def plot_from_file(): ...
def all_plots(): ...
def normalize_score(score: np.ndarray, env_id: str) -> np.ndarray: ...

HuggingFace Hub Integration

Model sharing and loading through HuggingFace Hub integration. Includes functions for uploading trained models, downloading pre-trained models, and generating model cards.

def package_to_hub(
    model: BaseAlgorithm,
    model_name: str,
    repo_id: str,
    commit_message: str = "Add model",
    **kwargs
) -> str: ...

def download_from_hub(
    repo_id: str,
    filename: str,
    **kwargs
) -> str: ...

def generate_model_card(
    model: BaseAlgorithm,
    env_id: str,
    **kwargs
) -> str: ...
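
The zoo's Hub helpers build on the huggingface_sb3 package, so a pre-trained checkpoint published by the sb3 organization can also be fetched directly with that package. A sketch using huggingface_sb3 rather than the rl_zoo3 wrappers; the repo and file names follow the sb3 naming convention and are assumptions:

from huggingface_sb3 import load_from_hub
from stable_baselines3 import PPO

# Download a pre-trained CartPole agent from the sb3 organization
checkpoint = load_from_hub(
    repo_id="sb3/ppo-CartPole-v1",   # assumed repo name ({algo}-{env})
    filename="ppo-CartPole-v1.zip",  # assumed checkpoint file name
)
model = PPO.load(checkpoint)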

Command Line Scripts

RL Zoo3 provides several command-line entry points:

# Main CLI entry point
rl_zoo3 train --algo ppo --env CartPole-v1
rl_zoo3 enjoy --algo ppo --env CartPole-v1 --folder logs/
rl_zoo3 plot_train --log-dir logs/
rl_zoo3 plot_from_file --log-dir logs/
rl_zoo3 all_plots --log-dir logs/

# Direct script execution
python -m rl_zoo3.train --algo ppo --env CartPole-v1
python -m rl_zoo3.enjoy --algo ppo --env CartPole-v1 --folder logs/
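
# Hyperparameter optimization with Optuna (tuning flags as in the upstream RL Zoo README; see --help for the full list)
python -m rl_zoo3.train --algo ppo --env CartPole-v1 -optimize --n-trials 100 --sampler tpe --pruner median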

Supported Algorithms

RL Zoo3 supports the following reinforcement learning algorithms:

  • A2C: Advantage Actor-Critic
  • DDPG: Deep Deterministic Policy Gradient
  • DQN: Deep Q-Network
  • PPO: Proximal Policy Optimization
  • SAC: Soft Actor-Critic
  • TD3: Twin Delayed Deep Deterministic Policy Gradient
  • ARS: Augmented Random Search
  • CrossQ: Off-policy algorithm using batch normalization for sample efficiency
  • QRDQN: Quantile Regression DQN
  • TQC: Truncated Quantile Critics
  • TRPO: Trust Region Policy Optimization
  • PPO_LSTM: Recurrent PPO (PPO with an LSTM policy)

Types

from typing import Any, Callable, Optional, Union
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.vec_env import VecEnv
import gymnasium as gym
import optuna
import torch as th

SimpleLinearSchedule = type  # Linear parameter scheduling class
StoreDict = type  # Argparse action for storing dict parameters

# Common type aliases
EnvironmentName = str
ModelName = str
HyperparamDict = dict[str, Any]
CallbackList = list[BaseCallback]
WrapperClass = Callable[[gym.Env], gym.Env]
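
StoreDict is an argparse action that turns key:value pairs on the command line (as used by flags such as --env-kwargs) into a Python dict, with values parsed as Python expressions. A minimal sketch of reusing it in a custom parser; the parsed example values here are only illustrative:

import argparse
from rl_zoo3.utils import StoreDict

parser = argparse.ArgumentParser()
# Each "key:value" pair is stored in a dict; values are evaluated as Python expressions
parser.add_argument("--env-kwargs", type=str, nargs="+", action=StoreDict, default={})
args = parser.parse_args(["--env-kwargs", "goal_velocity:0.1", "continuous:True"])
print(args.env_kwargs)  # {'goal_velocity': 0.1, 'continuous': True}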