
tessl/pypi-rl-zoo3

A Training Framework for Stable Baselines3 Reinforcement Learning Agents


Core Utilities

Core utilities for working with RL environments, models, and hyperparameters. They underpin the rest of the RL Zoo3 framework and cover algorithm selection, environment creation, model discovery, and parameter scheduling.

Core Imports

from rl_zoo3 import ALGOS, create_test_env, get_trained_models, linear_schedule
from rl_zoo3.utils import (
    get_model_path,
    get_saved_hyperparams, 
    get_latest_run_id,
    get_wrapper_class,
    get_class_by_name,
    flatten_dict_observations,
    get_callback_list
)

Capabilities

Algorithm Dictionary

Central registry mapping algorithm names to their Stable Baselines3 classes, enabling dynamic algorithm selection and instantiation.

ALGOS: dict[str, type[BaseAlgorithm]]

The ALGOS dictionary contains mappings for:

  • "a2c": A2C (Advantage Actor-Critic)
  • "ddpg": DDPG (Deep Deterministic Policy Gradient)
  • "dqn": DQN (Deep Q-Network)
  • "ppo": PPO (Proximal Policy Optimization)
  • "sac": SAC (Soft Actor-Critic)
  • "td3": TD3 (Twin Delayed Deep Deterministic Policy Gradient)
  • "ars": ARS (Augmented Random Search)
  • "crossq": CrossQ
  • "qrdqn": QRDQN (Quantile Regression DQN)
  • "tqc": TQC (Truncated Quantile Critics)
  • "trpo": TRPO (Trust Region Policy Optimization)
  • "ppo_lstm": RecurrentPPO (PPO with LSTM)

Usage example:

from rl_zoo3 import ALGOS
from stable_baselines3.common.env_util import make_vec_env

# Get the PPO algorithm class
ppo_class = ALGOS["ppo"]

# Create environment and model
env = make_vec_env("CartPole-v1", n_envs=1)
model = ppo_class("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000)
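
Because ALGOS is a plain dictionary, an algorithm name read from a CLI flag or config file can be validated before it is used. The sketch below illustrates that lookup pattern with a stand-in registry so that it runs without Stable Baselines3 installed. `select_algo`, `REGISTRY`, and the placeholder classes are hypothetical names for illustration and are not part of rl_zoo3.

```python
class FakePPO:
    """Placeholder standing in for an algorithm class (hypothetical)."""


class FakeSAC:
    """Placeholder standing in for an algorithm class (hypothetical)."""


# Stand-in for rl_zoo3.ALGOS; the real dictionary maps names to SB3 classes.
REGISTRY: dict[str, type] = {"ppo": FakePPO, "sac": FakeSAC}


def select_algo(name: str, registry: dict[str, type]) -> type:
    """Return the class registered under `name`, or raise a readable error."""
    key = name.lower()
    if key not in registry:
        available = ", ".join(sorted(registry))
        raise ValueError(f"Unknown algorithm '{name}'. Available: {available}")
    return registry[key]


algo_class = select_algo("PPO", REGISTRY)
print(algo_class.__name__)
```

With the real registry, the same helper could be called as `select_algo(name, ALGOS)`.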

Environment Creation

Creates vectorized test environments with proper wrappers, normalization, and configuration for evaluation and testing.

def create_test_env(
    env_id: str,
    n_envs: int = 1,
    stats_path: Optional[str] = None,
    seed: int = 0,
    log_dir: Optional[str] = None,
    should_render: bool = True,
    hyperparams: Optional[dict[str, Any]] = None,
    env_kwargs: Optional[dict[str, Any]] = None,
    vec_env_cls: Optional[type[VecEnv]] = None,
    vec_env_kwargs: Optional[dict[str, Any]] = None
) -> VecEnv:
    """
    Create a wrapped, monitored VecEnv for testing.
    
    Parameters:
    - env_id: Environment identifier (e.g., 'CartPole-v1')
    - n_envs: Number of parallel environments
    - stats_path: Path to VecNormalize statistics file
    - seed: Random seed for reproducibility
    - log_dir: Directory for logging environment interactions
    - should_render: Whether to enable rendering (for PyBullet envs)
    - hyperparams: Hyperparameters dict containing env_wrapper settings
    - env_kwargs: Additional keyword arguments for environment creation
    - vec_env_cls: VecEnv class constructor to use
    - vec_env_kwargs: Keyword arguments for VecEnv constructor
    
    Returns:
    VecEnv: Configured vectorized environment ready for testing
    """

Usage example:

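import numpy as np
from rl_zoo3 import create_test_env

# Create a test environment with 4 parallel copies
env = create_test_env(
    env_id="CartPole-v1",
    n_envs=4,
    seed=42,
    should_render=False,
    hyperparams={},  # no saved wrappers or normalization settings
)

# Step with random actions for demo purposes (one action per parallel env)
obs = env.reset()
for _ in range(1000):
    actions = np.array([env.action_space.sample() for _ in range(env.num_envs)])
    obs, rewards, dones, infos = env.step(actions)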

Model Discovery

Lists trained models found in a local log folder or on the HuggingFace Hub so they can be loaded and evaluated.

def get_trained_models(log_folder: str) -> dict[str, tuple[str, str]]:
    """
    Get a dictionary of trained models from a log folder.
    
    Parameters:
    - log_folder: Path to the directory containing trained models
    
    Returns:
    dict: Dictionary mapping "{algo}-{env_name}" keys to (algo, env_id) tuples
    """

def get_hf_trained_models(
    organization: str = "sb3",
    check_filename: bool = False
) -> dict[str, tuple[str, str]]:
    """
    Get trained models from HuggingFace Hub.
    
    Parameters:
    - organization: HuggingFace organization name
    - check_filename: Whether to validate model filenames
    
    Returns:
    dict: Dictionary mapping "{algo}-{env_name}" keys to (algo, env_id) tuples
    """

Usage example:

from rl_zoo3.utils import get_trained_models, get_hf_trained_models

# Get locally trained models
local_models = get_trained_models("./logs")
print("Local models:", local_models)

# Get models from HuggingFace Hub
hf_models = get_hf_trained_models(organization="sb3")
print("HF models:", list(hf_models.keys())[:5])  # Show first 5

Run Management

Utilities for managing training runs and finding the latest run ID for continued training or evaluation.

def get_latest_run_id(log_path: str, env_name: str) -> int:
    """
    Get the latest run ID for a given environment.
    
    Parameters:
    - log_path: Path to the log directory
    - env_name: Environment name
    
    Returns:
    int: Latest run ID (0 if no run folder exists yet)
    """

def get_model_path(
    exp_id: int,
    folder: str,
    algo: str,
    env_name: str,
    load_best: bool = False,
    load_checkpoint: Optional[str] = None,
    load_last_checkpoint: bool = False
) -> tuple[str, str, str]:
    """
    Get the path to a trained model and related information.
    
    Parameters:
    - exp_id: Experiment ID (0 for latest)
    - folder: Log folder path
    - algo: Algorithm name
    - env_name: Environment name
    - load_best: Whether to load the best model
    - load_checkpoint: Specific checkpoint to load (e.g., "100000")
    - load_last_checkpoint: Whether to load the last checkpoint
    
    Returns:
    tuple[str, str, str]: (name_prefix, model_path, log_path)
    """
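
The idea behind run IDs can be sketched with the standard library alone. The example assumes run folders are named `<env_name>_<run_id>` inside the log path, as the path `./logs/ppo/CartPole-v1_1/` used elsewhere in this document suggests. `latest_run_id` is a hypothetical stand-in written for illustration, not the library's implementation.

```python
import os
import re
import tempfile


def latest_run_id(log_path: str, env_name: str) -> int:
    """Return the largest N among folders named '<env_name>_<N>', or 0 if none exist."""
    pattern = re.compile(rf"^{re.escape(env_name)}_(\d+)$")
    max_id = 0
    if not os.path.isdir(log_path):
        return max_id
    for entry in os.listdir(log_path):
        match = pattern.match(entry)
        if match and os.path.isdir(os.path.join(log_path, entry)):
            max_id = max(max_id, int(match.group(1)))
    return max_id


with tempfile.TemporaryDirectory() as tmp:
    # No run folders yet, so the ID falls back to 0
    empty_id = latest_run_id(tmp, "CartPole-v1")
    for run in ("CartPole-v1_1", "CartPole-v1_3", "Pendulum-v1_7"):
        os.makedirs(os.path.join(tmp, run))
    # Only folders for the requested environment are considered
    found_id = latest_run_id(tmp, "CartPole-v1")

print(empty_id, found_id)
```

A new training run would then typically use `found_id + 1` as its folder suffix.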

Hyperparameter Management

Loading and managing saved hyperparameters and normalization statistics from trained models.

def get_saved_hyperparams(
    stats_path: str,
    norm_reward: bool = False,
    test_mode: bool = False
) -> tuple[dict[str, Any], str]:
    """
    Load saved hyperparameters from a stats file.
    
    Parameters:
    - stats_path: Path to the folder containing the saved configuration and normalization statistics
    - norm_reward: Whether reward normalization was used
    - test_mode: Whether in test mode
    
    Returns:
    tuple: (hyperparams_dict, stats_path)
    """

Usage example:

from rl_zoo3 import get_saved_hyperparams

# Load hyperparameters from trained model
hyperparams, stats_path = get_saved_hyperparams("./logs/ppo/CartPole-v1_1/")
print("Loaded hyperparams:", hyperparams)

Parameter Scheduling

Linear scheduling functions for hyperparameters like learning rate that need to change during training.

def linear_schedule(initial_value: Union[float, str]) -> SimpleLinearSchedule:
    """
    Create a linear schedule for a hyperparameter.
    
    Parameters:
    - initial_value: Initial value (float) or string representation
    
    Returns:
    SimpleLinearSchedule: Callable schedule object
    """

class SimpleLinearSchedule:
    """
    Linear parameter scheduling class.
    """
    def __init__(self, initial_value: float): ...
    def __call__(self, progress_remaining: float) -> float: ...

Usage example:

from rl_zoo3 import linear_schedule, ALGOS
from stable_baselines3.common.env_util import make_vec_env

# Create linear learning rate schedule
lr_schedule = linear_schedule(0.001)

# Use with PPO
env = make_vec_env("CartPole-v1", n_envs=1)
model = ALGOS["ppo"](
    "MlpPolicy", 
    env, 
    learning_rate=lr_schedule,
    verbose=1
)
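
To make the scheduling behavior concrete, here is a minimal pure-Python sketch of a linear schedule. It assumes the Stable Baselines3 convention that `progress_remaining` runs from 1.0 at the start of training to 0.0 at the end, and that the value decays linearly to zero. `make_linear_schedule` is a hypothetical name used only for this illustration.

```python
import math
from typing import Callable


def make_linear_schedule(initial_value: float) -> Callable[[float], float]:
    """Build a schedule that scales `initial_value` by the remaining progress."""

    def schedule(progress_remaining: float) -> float:
        # progress_remaining is assumed to go from 1.0 (start) to 0.0 (end)
        return progress_remaining * initial_value

    return schedule


lr = make_linear_schedule(0.001)
for progress in (1.0, 0.5, 0.0):
    print(f"progress_remaining={progress}: lr={lr(progress)}")

# Halfway through training the value is half of the initial one
assert math.isclose(lr(0.5), 0.0005)
```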

Environment Wrapper Utilities

Utilities for extracting and applying environment wrappers from hyperparameter configurations.

def get_wrapper_class(
    hyperparams: dict[str, Any],
    key: str = "env_wrapper"
) -> Optional[Callable[[gym.Env], gym.Env]]:
    """
    Get one or more Gym environment wrapper classes from hyperparams.
    
    Parameters:
    - hyperparams: Hyperparameters dictionary
    - key: Key in hyperparams containing wrapper specification
    
    Returns:
    Optional wrapper class or wrapper chain function
    """

def get_class_by_name(name: str) -> type:
    """
    Dynamically import a class by its name.
    
    Parameters:
    - name: Full class name (e.g., 'stable_baselines3.PPO')
    
    Returns:
    type: The imported class
    """

def flatten_dict_observations(env: gym.Env) -> gym.Env:
    """
    Flatten dictionary observation spaces.
    
    Parameters:
    - env: Environment with Dict observation space
    
    Returns:
    gym.Env: Environment with flattened observation space
    """
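
Resolving a class from a dotted path, as `get_class_by_name` is described as doing, can be sketched with `importlib`. This is an illustrative stand-in under the assumption that the name has the form `package.module.ClassName`. `class_by_name` is a hypothetical helper, not the library function.

```python
import importlib


def class_by_name(name: str) -> type:
    """Import the module part of a dotted path and return the named attribute."""
    module_name, _, class_name = name.rpartition(".")
    if not module_name:
        raise ValueError(f"Expected a dotted path like 'package.module.Class', got '{name}'")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# Standard-library classes are used here so the sketch runs anywhere
ordered_dict_cls = class_by_name("collections.OrderedDict")
print(ordered_dict_cls)
```

The same pattern lets wrapper or callback classes be named as strings in a hyperparameter file and imported only when needed.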

Callback Management

Utilities for creating and managing training callbacks from hyperparameter configurations.

def get_callback_list(hyperparams: dict[str, Any]) -> list[BaseCallback]:
    """
    Get callback list from hyperparams.
    
    Parameters:
    - hyperparams: Hyperparameters dictionary containing callback specifications
    
    Returns:
    list[BaseCallback]: List of configured callbacks
    """

Usage example:

from rl_zoo3.utils import get_callback_list

# Hyperparams with callbacks: each entry is a class path, or a single-key
# dict mapping the class path to its keyword arguments
hyperparams = {
    "callback": [
        {
            "stable_baselines3.common.callbacks.CheckpointCallback": {
                "save_freq": 1000,
                "save_path": "./checkpoints/",
            }
        }
    ]
}

# Get callback list
callbacks = get_callback_list(hyperparams)
print(f"Created {len(callbacks)} callbacks")
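
The sketch below shows one way a callback specification could be normalized into `(class path, kwargs)` pairs before the classes are instantiated. It assumes each entry is either a dotted class path string or a single-key dict mapping the path to keyword arguments. That format and the helper name `normalize_callback_specs` are assumptions made for illustration.

```python
from typing import Any, Union

# One entry: a class path, or {class path: kwargs}
Spec = Union[str, dict[str, dict[str, Any]]]


def normalize_callback_specs(
    spec: Union[Spec, list[Spec], None],
) -> list[tuple[str, dict[str, Any]]]:
    """Turn a single spec or a list of specs into (class_path, kwargs) pairs."""
    if spec is None:
        return []
    entries = spec if isinstance(spec, list) else [spec]
    pairs: list[tuple[str, dict[str, Any]]] = []
    for entry in entries:
        if isinstance(entry, dict):
            if len(entry) != 1:
                raise ValueError("A dict entry must contain exactly one class path")
            name, kwargs = next(iter(entry.items()))
            pairs.append((name, dict(kwargs or {})))
        else:
            # Bare string: no keyword arguments
            pairs.append((entry, {}))
    return pairs


pairs = normalize_callback_specs(
    [
        "my_package.callbacks.LoggingCallback",  # hypothetical class path
        {"my_package.callbacks.SaveCallback": {"save_freq": 1000}},  # hypothetical
    ]
)
print(pairs)
```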

Utility Classes

class StoreDict(argparse.Action):
    """
    Argparse action for storing dictionary parameters.
    Converts key:value pairs to dictionary entries; each value is
    evaluated as a Python expression.
    """
    def __call__(self, parser, namespace, values, option_string=None): ...

Usage example:

import argparse
from rl_zoo3.utils import StoreDict

parser = argparse.ArgumentParser()
parser.add_argument("--env-kwargs", type=str, nargs="+", action=StoreDict)
# String values need their own quotes because values are evaluated as Python
args = parser.parse_args(["--env-kwargs", "render_mode:'human'", "max_steps:1000"])
print(args.env_kwargs)  # {'render_mode': 'human', 'max_steps': 1000}
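
A dictionary-building argparse action can be sketched with the standard library alone. This version assumes `key:value` tokens and parses values with `ast.literal_eval`, which only accepts Python literals and falls back to the raw string otherwise. That choice is a deliberately conservative substitute, and `KeyValueAction` is a hypothetical class, not the library's `StoreDict`.

```python
import argparse
import ast


class KeyValueAction(argparse.Action):
    """Collect key:value tokens into a dict stored on the namespace."""

    def __call__(self, parser, namespace, values, option_string=None):
        result = {}
        for token in values:
            key, sep, raw = token.partition(":")
            if not sep or not key:
                parser.error(f"Expected key:value, got '{token}'")
            try:
                # Parse numbers, booleans, lists, quoted strings, ...
                result[key] = ast.literal_eval(raw)
            except (ValueError, SyntaxError):
                # Not a Python literal: keep it as a plain string
                result[key] = raw
        setattr(namespace, self.dest, result)


parser = argparse.ArgumentParser()
parser.add_argument("--env-kwargs", nargs="+", action=KeyValueAction, default={})
args = parser.parse_args(
    ["--env-kwargs", "max_steps:1000", "render_mode:human", "gravity:-9.8"]
)
print(args.env_kwargs)
```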

Install with Tessl CLI

npx tessl i tessl/pypi-rl-zoo3
