PyTorch version of Stable Baselines: implementations of reinforcement learning algorithms.
Implementation of six state-of-the-art deep reinforcement learning algorithms with consistent interfaces and extensive configuration options. Each algorithm is optimized for specific types of environments and learning scenarios.
On-policy algorithm that optimizes a clipped surrogate objective to ensure stable policy updates. Suitable for both continuous and discrete action spaces with excellent sample efficiency and stability.
class PPO(OnPolicyAlgorithm):
"""
Proximal Policy Optimization algorithm.
Args:
policy: Policy class or string ("MlpPolicy", "CnnPolicy", "MultiInputPolicy")
env: Environment or environment ID
learning_rate: Learning rate, can be a function of remaining progress
n_steps: Number of steps to run for each environment per update
batch_size: Minibatch size
n_epochs: Number of epochs when optimizing the surrogate loss
gamma: Discount factor
gae_lambda: Factor for trade-off of bias vs variance for GAE
clip_range: Clipping parameter for PPO surrogate objective
clip_range_vf: Clipping parameter for value function
normalize_advantage: Whether to normalize advantages
ent_coef: Entropy coefficient for exploration
vf_coef: Value function coefficient for loss calculation
max_grad_norm: Maximum value for gradient clipping
use_sde: Whether to use State Dependent Exploration
sde_sample_freq: Sample frequency for SDE
rollout_buffer_class: Rollout buffer class to use (None for default)
rollout_buffer_kwargs: Keyword arguments for rollout buffer creation
target_kl: Limit KL divergence between updates
stats_window_size: Window size for rollout logging averaging
tensorboard_log: Path to TensorBoard log directory
policy_kwargs: Additional arguments for policy construction
verbose: Verbosity level (0: no output, 1: info, 2: debug)
seed: Seed for random number generator
device: PyTorch device placement ("auto", "cpu", "cuda")
_init_setup_model: Whether to build network at creation
"""
def __init__(
self,
policy: Union[str, Type[ActorCriticPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 3e-4,
n_steps: int = 2048,
batch_size: int = 64,
n_epochs: int = 10,
gamma: float = 0.99,
gae_lambda: float = 0.95,
clip_range: Union[float, Schedule] = 0.2,
clip_range_vf: Optional[Union[float, Schedule]] = None,
normalize_advantage: bool = True,
ent_coef: float = 0.0,
vf_coef: float = 0.5,
max_grad_norm: float = 0.5,
use_sde: bool = False,
sde_sample_freq: int = -1,
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
target_kl: Optional[float] = None,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[torch.device, str] = "auto",
_init_setup_model: bool = True,
): ...On-policy algorithm that combines value-based and policy-based methods. Synchronous version of A3C with simpler implementation and often better performance than async counterparts.
class A2C(OnPolicyAlgorithm):
"""
Advantage Actor-Critic algorithm.
Args:
policy: Policy class or string ("MlpPolicy", "CnnPolicy", "MultiInputPolicy")
env: Environment or environment ID
learning_rate: Learning rate, can be a function of remaining progress
n_steps: Number of steps to run for each environment per update
gamma: Discount factor
gae_lambda: Factor for trade-off of bias vs variance for GAE
ent_coef: Entropy coefficient for exploration
vf_coef: Value function coefficient
max_grad_norm: Maximum value for gradient clipping
rms_prop_eps: RMSprop optimizer epsilon
use_rms_prop: Whether to use RMSprop optimizer (vs Adam)
use_sde: Whether to use State Dependent Exploration
sde_sample_freq: Sample frequency for SDE
rollout_buffer_class: Rollout buffer class to use (None for default)
rollout_buffer_kwargs: Keyword arguments for rollout buffer creation
normalize_advantage: Whether to normalize advantages
stats_window_size: Window size for rollout logging averaging
tensorboard_log: Path to TensorBoard log directory
policy_kwargs: Additional arguments for policy construction
verbose: Verbosity level
seed: Seed for random number generator
device: PyTorch device placement
_init_setup_model: Whether to build network at creation
"""
def __init__(
self,
policy: Union[str, Type[ActorCriticPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 7e-4,
n_steps: int = 5,
gamma: float = 0.99,
gae_lambda: float = 1.0,
ent_coef: float = 0.0,
vf_coef: float = 0.5,
max_grad_norm: float = 0.5,
rms_prop_eps: float = 1e-5,
use_rms_prop: bool = True,
use_sde: bool = False,
sde_sample_freq: int = -1,
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
normalize_advantage: bool = False,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[torch.device, str] = "auto",
_init_setup_model: bool = True,
): ...Off-policy algorithm that incorporates entropy regularization to encourage exploration. Particularly effective for continuous control tasks with excellent sample efficiency and stability.
class SAC(OffPolicyAlgorithm):
"""
Soft Actor-Critic algorithm.
Args:
policy: Policy class or string ("MlpPolicy", "CnnPolicy", "MultiInputPolicy")
env: Environment or environment ID
learning_rate: Learning rate, can be a function of remaining progress
buffer_size: Size of replay buffer
learning_starts: Steps before learning starts
batch_size: Minibatch size for training
tau: Soft update coefficient for target networks
gamma: Discount factor
train_freq: Update policy every n steps or episodes
gradient_steps: Gradient steps per update
action_noise: Action noise for exploration
replay_buffer_class: Replay buffer class
replay_buffer_kwargs: Additional replay buffer arguments
optimize_memory_usage: Enable memory optimizations
n_steps: Number of steps for n-step return calculation
ent_coef: Entropy regularization coefficient
target_update_interval: Update target network every n gradient steps
target_entropy: Target entropy for automatic entropy tuning
use_sde: Whether to use State Dependent Exploration
sde_sample_freq: Sample frequency for SDE
use_sde_at_warmup: Use SDE instead of uniform sampling during warmup
stats_window_size: Window size for rollout logging averaging
tensorboard_log: Path to TensorBoard log directory
policy_kwargs: Additional arguments for policy construction
verbose: Verbosity level
seed: Seed for random number generator
device: PyTorch device placement
_init_setup_model: Whether to build network at creation
"""
def __init__(
self,
policy: Union[str, Type[SACPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 3e-4,
buffer_size: int = 1_000_000,
learning_starts: int = 100,
batch_size: int = 256,
tau: float = 0.005,
gamma: float = 0.99,
train_freq: Union[int, Tuple[int, str]] = 1,
gradient_steps: int = 1,
action_noise: Optional[ActionNoise] = None,
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
optimize_memory_usage: bool = False,
n_steps: int = 1,
ent_coef: Union[str, float] = "auto",
target_update_interval: int = 1,
target_entropy: Union[str, float] = "auto",
use_sde: bool = False,
sde_sample_freq: int = -1,
use_sde_at_warmup: bool = False,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[torch.device, str] = "auto",
_init_setup_model: bool = True,
): ...Off-policy algorithm that addresses the overestimation bias in DDPG through twin critics and delayed policy updates. Excellent for continuous control with improved stability over DDPG.
class TD3(OffPolicyAlgorithm):
"""
Twin Delayed Deep Deterministic Policy Gradient algorithm.
Args:
policy: Policy class or string ("MlpPolicy", "CnnPolicy", "MultiInputPolicy")
env: Environment or environment ID
learning_rate: Learning rate, can be a function of remaining progress
buffer_size: Size of replay buffer
learning_starts: Steps before learning starts
batch_size: Minibatch size for training
tau: Soft update coefficient for target networks
gamma: Discount factor
train_freq: Update policy every n steps or episodes
gradient_steps: Gradient steps per update
action_noise: Action noise for exploration
replay_buffer_class: Replay buffer class
replay_buffer_kwargs: Additional replay buffer arguments
optimize_memory_usage: Enable memory optimizations
n_steps: Number of steps for n-step return calculation
policy_delay: Policy update delay (TD3 specific)
target_policy_noise: Noise added to target policy
target_noise_clip: Range to clip target policy noise
stats_window_size: Window size for rollout logging averaging
tensorboard_log: Path to TensorBoard log directory
policy_kwargs: Additional arguments for policy construction
verbose: Verbosity level
seed: Seed for random number generator
device: PyTorch device placement
_init_setup_model: Whether to build network at creation
"""
def __init__(
self,
policy: Union[str, Type[TD3Policy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 1e-3,
buffer_size: int = 1_000_000,
learning_starts: int = 100,
batch_size: int = 256,
tau: float = 0.005,
gamma: float = 0.99,
train_freq: Union[int, Tuple[int, str]] = 1,
gradient_steps: int = 1,
action_noise: Optional[ActionNoise] = None,
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
optimize_memory_usage: bool = False,
n_steps: int = 1,
policy_delay: int = 2,
target_policy_noise: float = 0.2,
target_noise_clip: float = 0.5,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[torch.device, str] = "auto",
_init_setup_model: bool = True,
): ...Off-policy algorithm for continuous control that combines DQN with policy gradients. Implemented as a special case of TD3 without the twin critics and delayed updates.
class DDPG(TD3):
"""
Deep Deterministic Policy Gradient algorithm.
Args:
Same as TD3 but with different default values:
- policy_delay: 1 (immediate policy updates)
- target_policy_noise: 0.0 (no target policy noise)
- target_noise_clip: 0.0 (no noise clipping)
"""
def __init__(
self,
policy: Union[str, Type[TD3Policy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 1e-4,
buffer_size: int = 1_000_000,
learning_starts: int = 100,
batch_size: int = 100,
tau: float = 0.005,
gamma: float = 0.99,
train_freq: Union[int, Tuple[int, str]] = (1, "episode"),
gradient_steps: int = -1,
action_noise: Optional[ActionNoise] = None,
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
optimize_memory_usage: bool = False,
tensorboard_log: Optional[str] = None,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[torch.device, str] = "auto",
_init_setup_model: bool = True,
): ...Off-policy value-based algorithm for discrete action spaces. Uses experience replay and target networks to stabilize learning of Q-values.
class DQN(OffPolicyAlgorithm):
"""
Deep Q-Network algorithm.
Args:
policy: Policy class or string ("MlpPolicy", "CnnPolicy", "MultiInputPolicy")
env: Environment or environment ID
learning_rate: Learning rate, can be a function of remaining progress
buffer_size: Size of replay buffer
learning_starts: Steps before learning starts
batch_size: Minibatch size for training
tau: Soft update coefficient (1.0 = hard update)
gamma: Discount factor
train_freq: Update policy every n steps
gradient_steps: Gradient steps per update
replay_buffer_class: Replay buffer class
replay_buffer_kwargs: Additional replay buffer arguments
optimize_memory_usage: Enable memory optimizations
n_steps: Number of steps for n-step return calculation
target_update_interval: Hard update interval for target network
exploration_fraction: Fraction of training for exploration decay
exploration_initial_eps: Initial exploration probability
exploration_final_eps: Final exploration probability
max_grad_norm: Maximum gradient norm
stats_window_size: Window size for rollout logging averaging
tensorboard_log: Path to TensorBoard log directory
policy_kwargs: Additional arguments for policy construction
verbose: Verbosity level
seed: Seed for random number generator
device: PyTorch device placement
_init_setup_model: Whether to build network at creation
"""
def __init__(
self,
policy: Union[str, Type[DQNPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 1e-4,
buffer_size: int = 1_000_000,
learning_starts: int = 100,
batch_size: int = 32,
tau: float = 1.0,
gamma: float = 0.99,
train_freq: Union[int, Tuple[int, str]] = 4,
gradient_steps: int = 1,
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
optimize_memory_usage: bool = False,
n_steps: int = 1,
target_update_interval: int = 10000,
exploration_fraction: float = 0.1,
exploration_initial_eps: float = 1.0,
exploration_final_eps: float = 0.05,
max_grad_norm: float = 10,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[torch.device, str] = "auto",
_init_setup_model: bool = True,
): ...All algorithms support three standard policy architectures that can be specified by string or class:
# Multi-layer perceptron policy for vector observations
MlpPolicy = "MlpPolicy"
# Convolutional neural network policy for image observations
CnnPolicy = "CnnPolicy"
# Multi-input policy for dictionary observations
MultiInputPolicy = "MultiInputPolicy"

import gymnasium as gym
from stable_baselines3 import PPO

# Create environment and agent
env = gym.make("CartPole-v1")
model = PPO("MlpPolicy", env, verbose=1)

# Train the agent
model.learn(total_timesteps=25000)

# Use the trained agent
obs, info = env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        obs, info = env.reset()

from stable_baselines3 import SAC
# Custom policy architecture
policy_kwargs = dict(
    net_arch=dict(pi=[400, 300], qf=[400, 300]),
    activation_fn=torch.nn.ReLU,
)
model = SAC(
    "MlpPolicy",
    env,
    policy_kwargs=policy_kwargs,
    learning_rate=3e-4,
    buffer_size=1000000,
    batch_size=256,
    verbose=1,
)

import numpy as np
from stable_baselines3 import TD3
from stable_baselines3.common.noise import NormalActionNoise

# Create action noise for exploration
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(
    mean=np.zeros(n_actions),
    sigma=0.1 * np.ones(n_actions),
)
model = TD3(
    "MlpPolicy",
    env,
    action_noise=action_noise,
    verbose=1,
)
model.learn(total_timesteps=100000)

from typing import Union, Optional, Type, Callable, Dict, Any, Tuple
import numpy as np
import torch
import gymnasium as gym
from stable_baselines3.common.type_aliases import GymEnv, Schedule, MaybeCallback
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.buffers import RolloutBuffer, ReplayBufferInstall with Tessl CLI
npx tessl i tessl/pypi-stable-baselines3