PyTorch version of Stable Baselines: reliable implementations of reinforcement learning algorithms.
```
npx @tessl/cli install tessl/pypi-stable-baselines3@2.7.0
```

Stable Baselines3 is a comprehensive Python library providing reliable implementations of state-of-the-art reinforcement learning algorithms using PyTorch. It offers a unified, sklearn-like interface for multiple RL algorithms with extensive customization options, designed to facilitate both research and practical deployment of deep reinforcement learning solutions.
```
pip install stable-baselines3
```

```python
from stable_baselines3 import PPO, A2C, SAC, TD3, DDPG, DQN
```

Common utilities and components:

```python
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.utils import set_random_seed, get_system_info
```

```python
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import EvalCallback
# Create environment
env = gym.make("CartPole-v1")
eval_env = gym.make("CartPole-v1")
# Wrap in vectorized environment for training
env = DummyVecEnv([lambda: env])
eval_env = DummyVecEnv([lambda: eval_env])
# Create PPO agent
model = PPO("MlpPolicy", env, verbose=1)
# Set up evaluation callback
eval_callback = EvalCallback(
    eval_env,
    best_model_save_path="./logs/",
    log_path="./logs/",
    eval_freq=10000,
    deterministic=True,
    render=False
)
# Train the agent
model.learn(total_timesteps=100000, callback=eval_callback)
# Save the trained model
model.save("ppo_cartpole")
# Load and use the trained model
model = PPO.load("ppo_cartpole")
# Test the trained agent
obs = env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    if done:
        obs = env.reset()
```

Stable Baselines3 follows a hierarchical architecture that promotes code reuse and extensibility:
- BaseAlgorithm: Core training loop and model management
- OnPolicyAlgorithm: Base for algorithms like A2C and PPO
- OffPolicyAlgorithm: Base for algorithms like SAC, TD3, DDPG, DQN

This design enables consistent interfaces across algorithms while allowing for algorithm-specific optimizations.
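As a rough sketch of what this shared interface makes possible (the helper function, environment ids, and timestep counts below are illustrative, not part of the library), the same training code can drive both an on-policy and an off-policy algorithm:

```python
from stable_baselines3 import PPO, SAC

def train(algo_cls, env_id, timesteps=10_000):
    # learn(), predict(), save(), and load() are inherited from BaseAlgorithm,
    # so the same helper works for any algorithm class.
    model = algo_cls("MlpPolicy", env_id, verbose=0)
    model.learn(total_timesteps=timesteps)
    model.save(f"{algo_cls.__name__.lower()}_{env_id}")
    return model

ppo_model = train(PPO, "CartPole-v1")   # on-policy, discrete action space
sac_model = train(SAC, "Pendulum-v1")   # off-policy, continuous action space
```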
Implementation of six state-of-the-art deep reinforcement learning algorithms with consistent interfaces and extensive configuration options.
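The constructor signatures below list each algorithm's main hyperparameters. As a hedged configuration sketch (the specific values and network sizes are arbitrary examples, not tuned recommendations), options are passed as keyword arguments:

```python
from stable_baselines3 import PPO, SAC

# On-policy: PPO with a smaller clip range and a custom network architecture.
ppo = PPO(
    "MlpPolicy",
    "CartPole-v1",
    learning_rate=2.5e-4,
    n_steps=1024,
    clip_range=0.1,
    policy_kwargs=dict(net_arch=[64, 64]),  # network size passed through policy_kwargs
    verbose=1,
)

# Off-policy: SAC with a larger replay buffer and batch size.
sac = SAC(
    "MlpPolicy",
    "Pendulum-v1",
    buffer_size=200_000,
    batch_size=512,
    tau=0.01,
    verbose=1,
)
```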
```python
class PPO(OnPolicyAlgorithm):
    """Proximal Policy Optimization algorithm."""
    def __init__(self, policy, env, learning_rate=3e-4, n_steps=2048, batch_size=64,
                 n_epochs=10, gamma=0.99, gae_lambda=0.95, clip_range=0.2, **kwargs): ...

class A2C(OnPolicyAlgorithm):
    """Advantage Actor-Critic algorithm."""
    def __init__(self, policy, env, learning_rate=7e-4, n_steps=5, gamma=0.99,
                 gae_lambda=1.0, ent_coef=0.0, vf_coef=0.5, **kwargs): ...

class SAC(OffPolicyAlgorithm):
    """Soft Actor-Critic algorithm."""
    def __init__(self, policy, env, learning_rate=3e-4, buffer_size=1000000,
                 batch_size=256, tau=0.005, gamma=0.99, **kwargs): ...

class TD3(OffPolicyAlgorithm):
    """Twin Delayed Deep Deterministic Policy Gradient algorithm."""
    def __init__(self, policy, env, learning_rate=1e-3, buffer_size=1000000,
                 batch_size=100, tau=0.005, gamma=0.99, **kwargs): ...

class DDPG(TD3):
    """Deep Deterministic Policy Gradient algorithm."""
    def __init__(self, policy, env, **kwargs): ...

class DQN(OffPolicyAlgorithm):
    """Deep Q-Network algorithm."""
    def __init__(self, policy, env, learning_rate=1e-4, buffer_size=1000000,
                 batch_size=32, tau=1.0, gamma=0.99, **kwargs): ...
```

Base classes, policies, and buffers that provide the foundation for all algorithms and enable consistent behavior across the library.
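Before the class signatures below, a short sketch of how the shared interface and off-policy buffer persistence are typically used together (file names and timestep counts are placeholders):

```python
import gymnasium as gym
from stable_baselines3 import DQN

model = DQN("MlpPolicy", "CartPole-v1", verbose=0)
model.learn(total_timesteps=5_000)

# save()/load() come from BaseAlgorithm and behave the same for every algorithm
model.save("dqn_cartpole")
# Off-policy algorithms additionally persist their ReplayBuffer
model.save_replay_buffer("dqn_cartpole_buffer")

# Reload and continue training with the previously collected experience
model = DQN.load("dqn_cartpole", env=gym.make("CartPole-v1"))
model.load_replay_buffer("dqn_cartpole_buffer")
model.learn(total_timesteps=5_000, reset_num_timesteps=False)
```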
```python
class BaseAlgorithm:
    """Abstract base class for all RL algorithms."""
    def learn(self, total_timesteps, callback=None, log_interval=4,
              tb_log_name="run", reset_num_timesteps=True, progress_bar=False): ...
    def predict(self, observation, state=None, episode_start=None, deterministic=False): ...
    def save(self, path): ...
    @classmethod
    def load(cls, path, env=None, device="auto", **kwargs): ...

class BasePolicy:
    """Base policy class for all neural network policies."""
    def forward(self, obs, deterministic=False): ...
    def predict(self, observation, state=None, episode_start=None, deterministic=False): ...

class RolloutBuffer:
    """Buffer for on-policy algorithms."""
    def add(self, obs, actions, rewards, episode_starts, values, log_probs): ...
    def get(self, batch_size=None): ...

class ReplayBuffer:
    """Experience replay buffer for off-policy algorithms."""
    def add(self, obs, next_obs, actions, rewards, dones, infos): ...
    def sample(self, batch_size, env=None): ...
```

Environment vectorization and wrappers for parallel training, normalization, monitoring, and other common preprocessing tasks.
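A hedged sketch of combining these wrappers (the environment id, number of workers, and file name are arbitrary choices for the example); the class signatures follow below:

```python
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv, VecNormalize

def make_env(env_id, seed):
    def _init():
        env = gym.make(env_id)
        env.reset(seed=seed)
        return env
    return _init

if __name__ == "__main__":
    # Run 4 environments in separate processes, then normalize observations and rewards
    venv = SubprocVecEnv([make_env("Pendulum-v1", seed=i) for i in range(4)])
    venv = VecNormalize(venv, norm_obs=True, norm_reward=True)

    model = PPO("MlpPolicy", venv, verbose=0)
    model.learn(total_timesteps=20_000)

    # Save the normalization statistics alongside the model for later evaluation
    venv.save("vecnormalize.pkl")
```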
```python
class DummyVecEnv:
    """Sequential vectorized environment."""
    def __init__(self, env_fns): ...
    def reset(self): ...
    def step(self, actions): ...

class SubprocVecEnv:
    """Multiprocessing vectorized environment."""
    def __init__(self, env_fns, start_method=None): ...

class VecNormalize:
    """Normalize observations and rewards."""
    def __init__(self, venv, training=True, norm_obs=True, norm_reward=True,
                 clip_obs=10.0, clip_reward=10.0, gamma=0.99, epsilon=1e-8): ...

class VecFrameStack:
    """Stack consecutive observations (frame stacking for image-based environments)."""
    def __init__(self, venv, n_stack, channels_order="last"): ...
```

Callbacks, noise generators, evaluation tools, and other utilities to enhance and monitor training processes.
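A sketch of these utilities working together around an off-policy run (save frequency, paths, and timestep counts are illustrative); signatures are listed after the example:

```python
import numpy as np
import gymnasium as gym
from stable_baselines3 import TD3
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.evaluation import evaluate_policy

env = gym.make("Pendulum-v1")
n_actions = env.action_space.shape[-1]

# Gaussian exploration noise for the deterministic TD3 policy
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

# Save a checkpoint every 5,000 steps
checkpoint_callback = CheckpointCallback(save_freq=5_000, save_path="./checkpoints/")

model = TD3("MlpPolicy", env, action_noise=action_noise, verbose=0)
model.learn(total_timesteps=20_000, callback=checkpoint_callback)

# Evaluate the trained policy over 10 episodes
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10, deterministic=True)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")
```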
```python
class EvalCallback:
    """Evaluate model during training."""
    def __init__(self, eval_env, callback_on_new_best=None, n_eval_episodes=5,
                 eval_freq=10000, log_path=None, best_model_save_path=None,
                 deterministic=True, render=False, verbose=1): ...

class CheckpointCallback:
    """Save model at regular intervals."""
    def __init__(self, save_freq, save_path, name_prefix="rl_model",
                 save_replay_buffer=False, save_vecnormalize=False, verbose=0): ...

class NormalActionNoise:
    """Gaussian action noise for exploration."""
    def __init__(self, mean, sigma): ...

def evaluate_policy(model, env, n_eval_episodes=10, deterministic=True,
                    render=False, callback=None, reward_threshold=None,
                    return_episode_rewards=False, warn=True, verbose=1): ...
```

Implementation of Hindsight Experience Replay for goal-conditioned reinforcement learning, enabling learning from failed attempts by treating them as successful attempts toward different goals.
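A sketch of wiring HER into an off-policy algorithm; it assumes a goal-conditioned environment with Dict observations (here "FetchReach-v2" from gymnasium-robotics is used as a placeholder id):

```python
import gymnasium as gym
from stable_baselines3 import SAC, HerReplayBuffer

# Assumes an environment whose observations are a Dict with "observation",
# "achieved_goal", and "desired_goal" keys (e.g. the Fetch tasks from
# gymnasium-robotics, which must be installed separately).
env = gym.make("FetchReach-v2")

model = SAC(
    "MultiInputPolicy",  # Dict observation spaces require the multi-input policy
    env,
    replay_buffer_class=HerReplayBuffer,
    replay_buffer_kwargs=dict(
        n_sampled_goal=4,
        goal_selection_strategy="future",  # relabel with goals achieved later in the episode
    ),
    verbose=0,
)
model.learn(total_timesteps=10_000)
```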
```python
class HerReplayBuffer:
    """Replay buffer with Hindsight Experience Replay."""
    def __init__(self, buffer_size, observation_space, action_space, device="auto",
                 n_envs=1, optimize_memory_usage=False, handle_timeout_termination=True,
                 n_sampled_goal=4, goal_selection_strategy="future", wrapped_env=None): ...

class GoalSelectionStrategy:
    """Enumeration of goal selection strategies."""
    FUTURE = "future"
    FINAL = "final"
    EPISODE = "episode"
    RANDOM = "random"
```

Utilities for retrieving system and environment information for debugging and reproducibility.
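A short usage sketch; the function signature is listed below:

```python
from stable_baselines3.common.utils import get_system_info

# Collect version info (OS, Python, PyTorch, SB3, ...) without printing it,
# then reuse the formatted string, e.g. when filing a bug report.
env_info_dict, env_info_str = get_system_info(print_info=False)
print(env_info_str)
```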
```python
def get_system_info(print_info: bool = True) -> tuple[dict[str, str], str]:
    """
    Retrieve system and Python environment info for the current system.

    Parameters:
    - print_info: Whether to print the collected information

    Returns:
    Tuple containing a dictionary with version info and a formatted string
    """
```

```python
from typing import Union, Optional, Callable, Dict, Any, List, Tuple
import numpy as np
import torch
import gymnasium as gym
# Environment types
GymEnv = Union[gym.Env, gym.Wrapper]
VecEnv = Union[DummyVecEnv, SubprocVecEnv]
# Policy types
Schedule = Callable[[float], float]
MaybeCallback = Union[None, Callable, List[Callable], "BaseCallback"]
# Algorithm-specific policy types
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
from stable_baselines3.sac.policies import SACPolicy
from stable_baselines3.td3.policies import TD3Policy
from stable_baselines3.dqn.policies import DQNPolicy
# Buffer types
from stable_baselines3.common.buffers import RolloutBuffer, ReplayBuffer
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
# Noise types
from stable_baselines3.common.noise import ActionNoise
# Observation and action types
PyTorchObs = Union[torch.Tensor, Dict[str, torch.Tensor]]
TensorDict = Dict[str, torch.Tensor]
# Training frequency specification
TrainFreq = Union[int, Tuple[int, str]]
```
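A minimal sketch of the Schedule alias in use: algorithms accept either a constant or a callable mapping the remaining training progress (from 1.0 down to 0.0) to a value, for parameters such as learning_rate and clip_range. The linear decay below is an illustrative choice, not a library default.

```python
from stable_baselines3 import PPO

def linear_schedule(initial_value: float):
    """Return a schedule that decays linearly from initial_value to 0."""
    def schedule(progress_remaining: float) -> float:
        # progress_remaining starts at 1.0 and decreases to 0.0 over training
        return progress_remaining * initial_value
    return schedule

model = PPO("MlpPolicy", "CartPole-v1", learning_rate=linear_schedule(3e-4))
```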