CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-stable-baselines3

Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.

Overview
Eval results
Files

docs/algorithms.md

Core Algorithms

Implementation of six state-of-the-art deep reinforcement learning algorithms with consistent interfaces and extensive configuration options. Each algorithm is optimized for specific types of environments and learning scenarios.

Capabilities

Proximal Policy Optimization (PPO)

On-policy algorithm that optimizes a clipped surrogate objective to ensure stable policy updates. Suitable for both continuous and discrete action spaces, offering excellent stability and reliable performance (though, as an on-policy method, it is typically less sample-efficient than off-policy alternatives such as SAC or TD3).

class PPO(OnPolicyAlgorithm):
    """
    Proximal Policy Optimization algorithm.

    On-policy actor-critic method that collects ``n_steps`` transitions per
    environment, then performs ``n_epochs`` of minibatch updates on a clipped
    surrogate objective, keeping each policy update close to the policy that
    collected the data.

    Args:
        policy: Policy class or string ("MlpPolicy", "CnnPolicy", "MultiInputPolicy")
        env: Environment or environment ID
        learning_rate: Learning rate, can be a function of remaining progress
        n_steps: Number of steps to run for each environment per update
        batch_size: Minibatch size
        n_epochs: Number of epochs when optimizing the surrogate loss
        gamma: Discount factor
        gae_lambda: Factor for trade-off of bias vs variance for GAE
        clip_range: Clipping parameter for PPO surrogate objective
        clip_range_vf: Clipping parameter for value function (None = no clipping)
        normalize_advantage: Whether to normalize advantages
        ent_coef: Entropy coefficient for exploration
        vf_coef: Value function coefficient for loss calculation
        max_grad_norm: Maximum value for gradient clipping
        use_sde: Whether to use State Dependent Exploration
        sde_sample_freq: Sample frequency for SDE (-1 = sample only at rollout start)
        rollout_buffer_class: Rollout buffer class to use (None for default)
        rollout_buffer_kwargs: Keyword arguments for rollout buffer creation
        target_kl: Limit KL divergence between updates (None = no limit)
        stats_window_size: Window size for rollout logging averaging
        tensorboard_log: Path to TensorBoard log directory
        policy_kwargs: Additional arguments for policy construction
        verbose: Verbosity level (0: no output, 1: info, 2: debug)
        seed: Seed for random number generator
        device: PyTorch device placement ("auto", "cpu", "cuda")
        _init_setup_model: Whether to build network at creation
    """
    def __init__(
        self,
        policy: Union[str, Type[ActorCriticPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 3e-4,
        n_steps: int = 2048,
        batch_size: int = 64,
        n_epochs: int = 10,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: Union[float, Schedule] = 0.2,
        clip_range_vf: Optional[Union[float, Schedule]] = None,
        normalize_advantage: bool = True,
        ent_coef: float = 0.0,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
        rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
        target_kl: Optional[float] = None,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[torch.device, str] = "auto",
        _init_setup_model: bool = True,
    ): ...

Advantage Actor-Critic (A2C)

On-policy algorithm that combines value-based and policy-based methods. Synchronous version of A3C with simpler implementation and often better performance than async counterparts.

class A2C(OnPolicyAlgorithm):
    """
    Advantage Actor-Critic algorithm.

    Synchronous variant of A3C: collects a short rollout of ``n_steps``
    transitions per environment and performs a single gradient update per
    rollout (no minibatch epochs, unlike PPO). Uses RMSprop by default.

    Args:
        policy: Policy class or string ("MlpPolicy", "CnnPolicy", "MultiInputPolicy")
        env: Environment or environment ID
        learning_rate: Learning rate, can be a function of remaining progress
        n_steps: Number of steps to run for each environment per update
        gamma: Discount factor
        gae_lambda: Factor for trade-off of bias vs variance for GAE
            (default 1.0 = plain Monte-Carlo advantage, no GAE smoothing)
        ent_coef: Entropy coefficient for exploration
        vf_coef: Value function coefficient
        max_grad_norm: Maximum value for gradient clipping
        rms_prop_eps: RMSprop optimizer epsilon
        use_rms_prop: Whether to use RMSprop optimizer (vs Adam)
        use_sde: Whether to use State Dependent Exploration
        sde_sample_freq: Sample frequency for SDE (-1 = sample only at rollout start)
        rollout_buffer_class: Rollout buffer class to use (None for default)
        rollout_buffer_kwargs: Keyword arguments for rollout buffer creation
        normalize_advantage: Whether to normalize advantages
        stats_window_size: Window size for rollout logging averaging
        tensorboard_log: Path to TensorBoard log directory
        policy_kwargs: Additional arguments for policy construction
        verbose: Verbosity level
        seed: Seed for random number generator
        device: PyTorch device placement
        _init_setup_model: Whether to build network at creation
    """
    def __init__(
        self,
        policy: Union[str, Type[ActorCriticPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 7e-4,
        n_steps: int = 5,
        gamma: float = 0.99,
        gae_lambda: float = 1.0,
        ent_coef: float = 0.0,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        rms_prop_eps: float = 1e-5,
        use_rms_prop: bool = True,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
        rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
        normalize_advantage: bool = False,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[torch.device, str] = "auto",
        _init_setup_model: bool = True,
    ): ...

Soft Actor-Critic (SAC)

Off-policy algorithm that incorporates entropy regularization to encourage exploration. Particularly effective for continuous control tasks with excellent sample efficiency and stability.

class SAC(OffPolicyAlgorithm):
    """
    Soft Actor-Critic algorithm.

    Off-policy maximum-entropy actor-critic for continuous action spaces.
    Trains from a replay buffer and augments the reward with an entropy
    bonus; with ``ent_coef="auto"`` the entropy coefficient is tuned
    automatically toward ``target_entropy``.

    Args:
        policy: Policy class or string ("MlpPolicy", "CnnPolicy", "MultiInputPolicy")
        env: Environment or environment ID
        learning_rate: Learning rate, can be a function of remaining progress
        buffer_size: Size of replay buffer
        learning_starts: Steps before learning starts
        batch_size: Minibatch size for training
        tau: Soft update coefficient for target networks
        gamma: Discount factor
        train_freq: Update policy every n steps or episodes
        gradient_steps: Gradient steps per update
        action_noise: Action noise for exploration
        replay_buffer_class: Replay buffer class
        replay_buffer_kwargs: Additional replay buffer arguments
        optimize_memory_usage: Enable memory optimizations
        n_steps: Number of steps for n-step return calculation
        ent_coef: Entropy regularization coefficient ("auto" = learned)
        target_update_interval: Update target network every n gradient steps
        target_entropy: Target entropy for automatic entropy tuning
            ("auto" = derived from the action space dimensionality)
        use_sde: Whether to use State Dependent Exploration
        sde_sample_freq: Sample frequency for SDE (-1 = sample only at rollout start)
        use_sde_at_warmup: Use SDE instead of uniform sampling during warmup
        stats_window_size: Window size for rollout logging averaging
        tensorboard_log: Path to TensorBoard log directory
        policy_kwargs: Additional arguments for policy construction
        verbose: Verbosity level
        seed: Seed for random number generator
        device: PyTorch device placement
        _init_setup_model: Whether to build network at creation
    """
    def __init__(
        self,
        policy: Union[str, Type[SACPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 3e-4,
        buffer_size: int = 1_000_000,
        learning_starts: int = 100,
        batch_size: int = 256,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = 1,
        gradient_steps: int = 1,
        action_noise: Optional[ActionNoise] = None,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        n_steps: int = 1,
        ent_coef: Union[str, float] = "auto",
        target_update_interval: int = 1,
        target_entropy: Union[str, float] = "auto",
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        use_sde_at_warmup: bool = False,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[torch.device, str] = "auto",
        _init_setup_model: bool = True,
    ): ...

Twin Delayed Deep Deterministic Policy Gradient (TD3)

Off-policy algorithm that addresses the overestimation bias in DDPG through twin critics and delayed policy updates. Excellent for continuous control with improved stability over DDPG.

class TD3(OffPolicyAlgorithm):
    """
    Twin Delayed Deep Deterministic Policy Gradient algorithm.

    Off-policy actor-critic for continuous action spaces. Mitigates DDPG's
    Q-value overestimation via twin critics (minimum of two Q estimates),
    target-policy smoothing noise, and delayed actor updates
    (``policy_delay``).

    Args:
        policy: Policy class or string ("MlpPolicy", "CnnPolicy", "MultiInputPolicy")
        env: Environment or environment ID
        learning_rate: Learning rate, can be a function of remaining progress
        buffer_size: Size of replay buffer
        learning_starts: Steps before learning starts
        batch_size: Minibatch size for training
        tau: Soft update coefficient for target networks
        gamma: Discount factor
        train_freq: Update policy every n steps or episodes
        gradient_steps: Gradient steps per update
        action_noise: Action noise for exploration
        replay_buffer_class: Replay buffer class
        replay_buffer_kwargs: Additional replay buffer arguments
        optimize_memory_usage: Enable memory optimizations
        n_steps: Number of steps for n-step return calculation
        policy_delay: Update the actor only every n critic updates (TD3 specific)
        target_policy_noise: Noise added to target policy (smoothing)
        target_noise_clip: Range to clip target policy noise
        stats_window_size: Window size for rollout logging averaging
        tensorboard_log: Path to TensorBoard log directory
        policy_kwargs: Additional arguments for policy construction
        verbose: Verbosity level
        seed: Seed for random number generator
        device: PyTorch device placement
        _init_setup_model: Whether to build network at creation
    """
    def __init__(
        self,
        policy: Union[str, Type[TD3Policy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 1e-3,
        buffer_size: int = 1_000_000,
        learning_starts: int = 100,
        batch_size: int = 256,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = 1,
        gradient_steps: int = 1,
        action_noise: Optional[ActionNoise] = None,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        n_steps: int = 1,
        policy_delay: int = 2,
        target_policy_noise: float = 0.2,
        target_noise_clip: float = 0.5,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[torch.device, str] = "auto",
        _init_setup_model: bool = True,
    ): ...

Deep Deterministic Policy Gradient (DDPG)

Off-policy algorithm for continuous control that combines DQN with policy gradients. Implemented as a special case of TD3 without the twin critics and delayed updates.

class DDPG(TD3):
    """
    Deep Deterministic Policy Gradient algorithm.

    Implemented as a restricted TD3: the TD3-specific knobs are fixed to
    values that disable the TD3 extensions and are therefore absent from
    this signature:
        - policy_delay: 1 (immediate policy updates)
        - target_policy_noise: 0.0 (no target policy noise)
        - target_noise_clip: 0.0 (no noise clipping)

    Remaining arguments match TD3 but with different defaults:
    learning_rate=1e-4, batch_size=100, train_freq=(1, "episode"),
    gradient_steps=-1 (presumably one gradient step per environment step
    collected during the episode — confirm against the TD3 docstring).
    """
    def __init__(
        self,
        policy: Union[str, Type[TD3Policy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 1e-4,
        buffer_size: int = 1_000_000,
        learning_starts: int = 100,
        batch_size: int = 100,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = (1, "episode"),
        gradient_steps: int = -1,
        action_noise: Optional[ActionNoise] = None,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[torch.device, str] = "auto",
        _init_setup_model: bool = True,
    ): ...

Deep Q-Network (DQN)

Off-policy value-based algorithm for discrete action spaces. Uses experience replay and target networks to stabilize learning of Q-values.

class DQN(OffPolicyAlgorithm):
    """
    Deep Q-Network algorithm.

    Off-policy value-based method for discrete action spaces. Learns
    Q-values from a replay buffer with a periodically hard-updated target
    network, and explores with an epsilon-greedy policy whose epsilon is
    linearly annealed from ``exploration_initial_eps`` to
    ``exploration_final_eps`` over ``exploration_fraction`` of training.

    Args:
        policy: Policy class or string ("MlpPolicy", "CnnPolicy", "MultiInputPolicy")
        env: Environment or environment ID
        learning_rate: Learning rate, can be a function of remaining progress
        buffer_size: Size of replay buffer
        learning_starts: Steps before learning starts
        batch_size: Minibatch size for training
        tau: Soft update coefficient (default 1.0 = hard update)
        gamma: Discount factor
        train_freq: Update policy every n steps
        gradient_steps: Gradient steps per update
        replay_buffer_class: Replay buffer class
        replay_buffer_kwargs: Additional replay buffer arguments
        optimize_memory_usage: Enable memory optimizations
        n_steps: Number of steps for n-step return calculation
        target_update_interval: Hard update interval for target network
        exploration_fraction: Fraction of training for exploration decay
        exploration_initial_eps: Initial exploration probability
        exploration_final_eps: Final exploration probability
        max_grad_norm: Maximum gradient norm
        stats_window_size: Window size for rollout logging averaging
        tensorboard_log: Path to TensorBoard log directory
        policy_kwargs: Additional arguments for policy construction
        verbose: Verbosity level
        seed: Seed for random number generator
        device: PyTorch device placement
        _init_setup_model: Whether to build network at creation
    """
    def __init__(
        self,
        policy: Union[str, Type[DQNPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 1e-4,
        buffer_size: int = 1_000_000,
        learning_starts: int = 100,
        batch_size: int = 32,
        tau: float = 1.0,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = 4,
        gradient_steps: int = 1,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        n_steps: int = 1,
        target_update_interval: int = 10000,
        exploration_fraction: float = 0.1,
        exploration_initial_eps: float = 1.0,
        exploration_final_eps: float = 0.05,
        max_grad_norm: float = 10,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[torch.device, str] = "auto",
        _init_setup_model: bool = True,
    ): ...

Policy Types

All algorithms support three standard policy architectures that can be specified by string or class:

# Policy aliases: pass the string (or the corresponding class) as the
# `policy` argument of any algorithm constructor.

# Multi-layer perceptron policy for flat vector observations
MlpPolicy = "MlpPolicy"

# Convolutional neural network policy for image observations
CnnPolicy = "CnnPolicy"

# Multi-input policy for dictionary (Dict-space) observations
MultiInputPolicy = "MultiInputPolicy"

Usage Examples

Basic Algorithm Training

import gymnasium as gym
from stable_baselines3 import PPO

# Create environment and agent
env = gym.make("CartPole-v1")
model = PPO("MlpPolicy", env, verbose=1)

# Train the agent
model.learn(total_timesteps=25000)

# Use the trained agent (deterministic=True disables exploration sampling)
obs, info = env.reset()
for _ in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    # Gymnasium reports episode end as terminated (MDP end) or truncated (time limit)
    if terminated or truncated:
        obs, info = env.reset()

Custom Policy Networks

import gymnasium as gym
import torch

from stable_baselines3 import SAC

# Environment with a continuous action space (required by SAC)
env = gym.make("Pendulum-v1")

# Custom policy architecture: separate layer sizes for the actor (pi)
# and the Q-function critics (qf)
policy_kwargs = dict(
    net_arch=dict(pi=[400, 300], qf=[400, 300]),
    activation_fn=torch.nn.ReLU,
)

model = SAC(
    "MlpPolicy",
    env,
    policy_kwargs=policy_kwargs,
    learning_rate=3e-4,
    buffer_size=1000000,
    batch_size=256,
    verbose=1,
)

Continuous Control with Noise

import gymnasium as gym
import numpy as np

from stable_baselines3 import TD3
from stable_baselines3.common.noise import NormalActionNoise

# Environment with a continuous action space (required by TD3)
env = gym.make("Pendulum-v1")

# Gaussian action noise for exploration, with sigma at 10% of unit scale
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(
    mean=np.zeros(n_actions),
    sigma=0.1 * np.ones(n_actions),
)

model = TD3(
    "MlpPolicy",
    env,
    action_noise=action_noise,
    verbose=1,
)

model.learn(total_timesteps=100000)

Types

from typing import Union, Optional, Type, Callable, Dict, Any, Tuple
import numpy as np
import torch
import gymnasium as gym
from stable_baselines3.common.type_aliases import GymEnv, Schedule, MaybeCallback
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.buffers import RolloutBuffer, ReplayBuffer

Install with Tessl CLI

npx tessl i tessl/pypi-stable-baselines3

docs

algorithms.md

common-framework.md

environments.md

her.md

index.md

training-utilities.md

tile.json