tessl/pypi-stable-baselines3

PyTorch version of Stable Baselines, providing implementations of reinforcement learning algorithms.

Describes: pkg:pypi/stable-baselines3@2.7.x

To install, run

npx @tessl/cli install tessl/pypi-stable-baselines3@2.7.0


Stable Baselines3

Stable Baselines3 is a comprehensive Python library providing reliable implementations of state-of-the-art reinforcement learning algorithms in PyTorch. It offers a unified, scikit-learn-like interface across algorithms with extensive customization options, and is designed to support both research and practical deployment of deep reinforcement learning.

Package Information

  • Package Name: stable-baselines3
  • Language: Python
  • Installation: pip install stable-baselines3

Core Imports

from stable_baselines3 import PPO, A2C, SAC, TD3, DDPG, DQN

Common utilities and components:

from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.utils import set_random_seed, get_system_info

Basic Usage

import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import EvalCallback

# Create environment
env = gym.make("CartPole-v1")
eval_env = gym.make("CartPole-v1")

# Wrap in vectorized environment for training
env = DummyVecEnv([lambda: env])
eval_env = DummyVecEnv([lambda: eval_env])

# Create PPO agent
model = PPO("MlpPolicy", env, verbose=1)

# Set up evaluation callback
eval_callback = EvalCallback(
    eval_env,
    best_model_save_path="./logs/",
    log_path="./logs/",
    eval_freq=10000,
    deterministic=True,
    render=False
)

# Train the agent
model.learn(total_timesteps=100000, callback=eval_callback)

# Save the trained model
model.save("ppo_cartpole")

# Load and use the trained model
model = PPO.load("ppo_cartpole")

# Test the trained agent (env is a VecEnv, so step() returns batched arrays
# and finished episodes are reset automatically)
obs = env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, rewards, dones, infos = env.step(action)

Architecture

Stable Baselines3 follows a hierarchical architecture that promotes code reuse and extensibility:

  • Algorithm Classes: Implementation of specific RL algorithms (PPO, SAC, etc.)
  • Base Classes: Abstract base classes providing common functionality
    • BaseAlgorithm: Core training loop and model management
    • OnPolicyAlgorithm: Base for algorithms like A2C and PPO
    • OffPolicyAlgorithm: Base for algorithms like SAC, TD3, DDPG, DQN
  • Policies: Neural network architectures for different observation spaces
  • Buffers: Experience storage for training (rollout buffers, replay buffers)
  • Common Components: Utilities, callbacks, environment wrappers, and evaluation tools

This design enables consistent interfaces across algorithms while allowing for algorithm-specific optimizations.
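
Because every algorithm inherits from BaseAlgorithm, the training and inference workflow is the same regardless of which class is used. A minimal sketch (the environment id and timestep count are illustrative):

import gymnasium as gym
from stable_baselines3 import A2C, PPO

# The same constructor/learn/predict workflow applies to any algorithm class.
for algo_cls in (A2C, PPO):
    model = algo_cls("MlpPolicy", "CartPole-v1", verbose=0)
    model.learn(total_timesteps=5_000)

    test_env = gym.make("CartPole-v1")
    obs, _ = test_env.reset()
    action, _states = model.predict(obs, deterministic=True)
    print(f"{algo_cls.__name__} selected action {action}")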

Capabilities

Core Algorithms

Implementation of six state-of-the-art deep reinforcement learning algorithms with consistent interfaces and extensive configuration options.

class PPO(OnPolicyAlgorithm):
    """Proximal Policy Optimization algorithm."""
    def __init__(self, policy, env, learning_rate=3e-4, n_steps=2048, batch_size=64, 
                 n_epochs=10, gamma=0.99, gae_lambda=0.95, clip_range=0.2, **kwargs): ...

class A2C(OnPolicyAlgorithm):
    """Advantage Actor-Critic algorithm."""
    def __init__(self, policy, env, learning_rate=7e-4, n_steps=5, gamma=0.99, 
                 gae_lambda=1.0, ent_coef=0.0, vf_coef=0.5, **kwargs): ...

class SAC(OffPolicyAlgorithm):
    """Soft Actor-Critic algorithm."""
    def __init__(self, policy, env, learning_rate=3e-4, buffer_size=1000000, 
                 batch_size=256, tau=0.005, gamma=0.99, **kwargs): ...

class TD3(OffPolicyAlgorithm):
    """Twin Delayed Deep Deterministic Policy Gradient algorithm."""
    def __init__(self, policy, env, learning_rate=1e-3, buffer_size=1000000, 
                 batch_size=100, tau=0.005, gamma=0.99, **kwargs): ...

class DDPG(TD3):
    """Deep Deterministic Policy Gradient algorithm."""
    def __init__(self, policy, env, **kwargs): ...

class DQN(OffPolicyAlgorithm):
    """Deep Q-Network algorithm."""
    def __init__(self, policy, env, learning_rate=1e-4, buffer_size=1000000, 
                 batch_size=32, tau=1.0, gamma=0.99, **kwargs): ...
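
For continuous-control tasks, the off-policy algorithms accept an action_noise argument for exploration. A minimal sketch using TD3 on Pendulum-v1 (hyperparameters are illustrative):

import numpy as np
import gymnasium as gym
from stable_baselines3 import TD3
from stable_baselines3.common.noise import NormalActionNoise

env = gym.make("Pendulum-v1")
n_actions = env.action_space.shape[-1]

# Gaussian exploration noise is commonly added for deterministic-policy
# algorithms such as TD3 and DDPG.
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

model = TD3("MlpPolicy", env, action_noise=action_noise, verbose=1)
model.learn(total_timesteps=10_000)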

See the Core Algorithms documentation for the complete API.

Common Framework

Base classes, policies, and buffers that provide the foundation for all algorithms and enable consistent behavior across the library.

class BaseAlgorithm:
    """Abstract base class for all RL algorithms."""
    def learn(self, total_timesteps, callback=None, log_interval=4, 
              tb_log_name="run", reset_num_timesteps=True, progress_bar=False): ...
    def predict(self, observation, state=None, episode_start=None, deterministic=False): ...
    def save(self, path): ...
    @classmethod
    def load(cls, path, env=None, device="auto", **kwargs): ...

class BasePolicy:
    """Base policy class for all neural network policies."""
    def forward(self, obs, deterministic=False): ...
    def predict(self, observation, state=None, episode_start=None, deterministic=False): ...

class RolloutBuffer:
    """Buffer for on-policy algorithms."""
    def add(self, obs, actions, rewards, episode_starts, values, log_probs): ...
    def get(self, batch_size=None): ...

class ReplayBuffer:
    """Experience replay buffer for off-policy algorithms."""
    def add(self, obs, next_obs, actions, rewards, dones, infos): ...
    def sample(self, batch_size, env=None): ...
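
Because save(), load() and predict() are defined on BaseAlgorithm, and policies expose a common policy_kwargs hook, the same customization and persistence pattern works for every algorithm. A minimal sketch (network sizes, timesteps and file names are illustrative):

from stable_baselines3 import PPO

# policy_kwargs customizes the network built by the policy class;
# net_arch here requests two hidden layers of 64 units each.
model = PPO(
    "MlpPolicy",
    "CartPole-v1",
    policy_kwargs=dict(net_arch=[64, 64]),
    verbose=0,
)
model.learn(total_timesteps=5_000)

# save()/load() come from BaseAlgorithm, so the same round trip
# works for every algorithm class.
model.save("ppo_cartpole_small")
reloaded = PPO.load("ppo_cartpole_small")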

See the Common Framework documentation for the complete API.

Vectorized Environments

Environment vectorization and wrappers for parallel training, normalization, monitoring, and other common preprocessing tasks.

class DummyVecEnv:
    """Sequential vectorized environment."""
    def __init__(self, env_fns): ...
    def reset(self): ...
    def step(self, actions): ...

class SubprocVecEnv:
    """Multiprocessing vectorized environment."""
    def __init__(self, env_fns, start_method=None): ...

class VecNormalize:
    """Normalize observations and rewards."""
    def __init__(self, venv, training=True, norm_obs=True, norm_reward=True, 
                 clip_obs=10.0, clip_reward=10.0, gamma=0.99, epsilon=1e-8): ...

class VecFrameStack:
    """Stack frames for recurrent policies."""
    def __init__(self, venv, n_stack, channels_order="last"): ...
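
A typical pattern is to build several parallel environments with the make_vec_env helper and wrap them in VecNormalize. A minimal sketch (environment id, number of workers and file names are illustrative):

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv, VecNormalize

if __name__ == "__main__":  # guard needed because SubprocVecEnv spawns worker processes
    # Run 4 copies of the environment in parallel processes.
    venv = make_vec_env("CartPole-v1", n_envs=4, vec_env_cls=SubprocVecEnv)

    # Normalize observations and rewards with running statistics.
    venv = VecNormalize(venv, norm_obs=True, norm_reward=True)

    model = PPO("MlpPolicy", venv, verbose=0)
    model.learn(total_timesteps=10_000)

    # Persist the normalization statistics so the same scaling can be
    # restored at evaluation time.
    venv.save("vecnormalize.pkl")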

See the Environments documentation for the complete API.

Training Utilities

Callbacks, noise generators, evaluation tools, and other utilities to enhance and monitor training processes.

class EvalCallback:
    """Evaluate model during training."""
    def __init__(self, eval_env, callback_on_new_best=None, n_eval_episodes=5, 
                 eval_freq=10000, log_path=None, best_model_save_path=None, 
                 deterministic=True, render=False, verbose=1): ...

class CheckpointCallback:
    """Save model at regular intervals."""
    def __init__(self, save_freq, save_path, name_prefix="rl_model", 
                 save_replay_buffer=False, save_vecnormalize=False, verbose=0): ...

class NormalActionNoise:
    """Gaussian action noise for exploration."""
    def __init__(self, mean, sigma): ...

def evaluate_policy(model, env, n_eval_episodes=10, deterministic=True, 
                   render=False, callback=None, reward_threshold=None, 
                   return_episode_rewards=False, warn=True, verbose=1): ...
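
Callbacks can be combined by passing a list to learn(), and evaluate_policy() provides standalone evaluation after training. A minimal sketch (frequencies, paths and timesteps are illustrative):

import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
from stable_baselines3.common.evaluation import evaluate_policy

env = gym.make("CartPole-v1")
eval_env = gym.make("CartPole-v1")

# Multiple callbacks are accepted as a list.
checkpoint_cb = CheckpointCallback(save_freq=5_000, save_path="./checkpoints/")
eval_cb = EvalCallback(eval_env, eval_freq=5_000, best_model_save_path="./best/")

model = PPO("MlpPolicy", env, verbose=0)
model.learn(total_timesteps=20_000, callback=[checkpoint_cb, eval_cb])

# Standalone evaluation after training.
mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10)
print(f"mean reward: {mean_reward:.1f} +/- {std_reward:.1f}")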

See the Training Utilities documentation for the complete API.

Hindsight Experience Replay (HER)

Implementation of Hindsight Experience Replay for goal-conditioned reinforcement learning, enabling learning from failed attempts by treating them as successful attempts toward different goals.

class HerReplayBuffer:
    """Replay buffer with Hindsight Experience Replay."""
    def __init__(self, buffer_size, observation_space, action_space, device="auto", 
                 n_envs=1, optimize_memory_usage=False, handle_timeout_termination=True,
                 n_sampled_goal=4, goal_selection_strategy="future", wrapped_env=None): ...

class GoalSelectionStrategy:
    """Enumeration of goal selection strategies."""
    FUTURE = "future"
    FINAL = "final" 
    EPISODE = "episode"
    RANDOM = "random"
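
HER is used by passing HerReplayBuffer as the replay_buffer_class of an off-policy algorithm; the environment must be goal-conditioned, with dict observations containing observation, achieved_goal and desired_goal keys plus a compute_reward() method. A minimal sketch, where "YourGoalEnv-v0" is a placeholder for such an environment:

from stable_baselines3 import SAC, HerReplayBuffer

# "YourGoalEnv-v0" is a placeholder: substitute any registered
# goal-conditioned environment with dict observations.
model = SAC(
    "MultiInputPolicy",              # dict observations require the multi-input policy
    "YourGoalEnv-v0",
    replay_buffer_class=HerReplayBuffer,
    replay_buffer_kwargs=dict(
        n_sampled_goal=4,
        goal_selection_strategy="future",
    ),
    verbose=1,
)
model.learn(total_timesteps=50_000)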

See the HER documentation for the complete API.

System Information

Utilities for retrieving system and environment information for debugging and reproducibility.

def get_system_info(print_info: bool = True) -> tuple[dict[str, str], str]:
    """
    Retrieve system and python env info for the current system.
    
    Parameters:
    - print_info: Whether to print the collected information
    
    Returns:
    Tuple containing dictionary with version info and formatted string
    """

Types

from typing import Union, Optional, Callable, Dict, Any, List, Tuple
import numpy as np
import torch
import gymnasium as gym

# Environment types
GymEnv = Union[gym.Env, gym.Wrapper]
VecEnv = Union[DummyVecEnv, SubprocVecEnv]

# Policy types  
Schedule = Callable[[float], float]
MaybeCallback = Union[None, Callable, List[Callable], "BaseCallback"]

# Algorithm-specific policy types
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
from stable_baselines3.sac.policies import SACPolicy
from stable_baselines3.td3.policies import TD3Policy
from stable_baselines3.dqn.policies import DQNPolicy

# Buffer types
from stable_baselines3.common.buffers import RolloutBuffer, ReplayBuffer
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer

# Noise types
from stable_baselines3.common.noise import ActionNoise

# Observation and action types
PyTorchObs = Union[torch.Tensor, Dict[str, torch.Tensor]]
TensorDict = Dict[str, torch.Tensor]

# Training frequency specification
TrainFreq = Union[int, Tuple[int, str]]
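
A Schedule maps the remaining training progress (1.0 at the start of training, 0.0 at the end) to a value, and can be passed in place of constants such as learning_rate or clip_range. A minimal sketch:

from stable_baselines3 import PPO

def linear_schedule(initial_value: float):
    """Return a Schedule that decays linearly from initial_value to 0."""
    def schedule(progress_remaining: float) -> float:
        # progress_remaining goes from 1.0 (start of training) to 0.0 (end)
        return progress_remaining * initial_value
    return schedule

model = PPO(
    "MlpPolicy",
    "CartPole-v1",
    learning_rate=linear_schedule(3e-4),
    clip_range=linear_schedule(0.2),
)
model.learn(total_timesteps=10_000)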