CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-stable-baselines3

PyTorch version of Stable Baselines: implementations of reinforcement learning algorithms.

Overview
Eval results
Files

common-framework.mddocs/

Common Framework

Base classes, policies, and buffers that provide the foundation for all algorithms and enable consistent behavior across the Stable Baselines3 library. This framework promotes code reuse and ensures uniform interfaces across different algorithm implementations.

Capabilities

Base Algorithm Classes

Abstract base classes that define the core functionality shared by all reinforcement learning algorithms, including training loops, model management, and prediction interfaces.

class BaseAlgorithm:
    """
    Abstract base class for all RL algorithms.

    Concrete algorithms subclass this (via OnPolicyAlgorithm or
    OffPolicyAlgorithm) and implement the training logic; this class
    defines the shared interface for training, prediction, persistence,
    and environment management.

    Args:
        policy: Policy class or string identifier
        env: Environment or environment ID
        learning_rate: Learning rate for optimization (constant float or Schedule)
        policy_kwargs: Additional arguments for policy construction
        stats_window_size: Window size for rollout logging averaging
        tensorboard_log: Path to TensorBoard log directory
        verbose: Verbosity level (0: no output, 1: info, 2: debug)
        device: PyTorch device placement ("auto", "cpu", "cuda")
        support_multi_env: Whether algorithm supports multiple environments
        monitor_wrapper: Whether to wrap environment with Monitor
        seed: Random seed for reproducibility
        use_sde: Whether to use State Dependent Exploration
        sde_sample_freq: Sample frequency for SDE
        supported_action_spaces: List of supported action spaces
    """
    def __init__(
        self,
        policy: Union[str, Type[BasePolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule],
        policy_kwargs: Optional[Dict[str, Any]] = None,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        verbose: int = 0,
        device: Union[torch.device, str] = "auto",
        support_multi_env: bool = False,
        monitor_wrapper: bool = True,
        seed: Optional[int] = None,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        supported_action_spaces: Optional[Tuple[Type[gym.Space], ...]] = None,
    ): ...

    def learn(
        self,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        tb_log_name: str = "run",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> "BaseAlgorithm":
        """
        Train the agent for total_timesteps.

        Args:
            total_timesteps: Total number of timesteps to train
            callback: Callback(s) called during training
            log_interval: Log interval for training metrics
            tb_log_name: TensorBoard log name
            reset_num_timesteps: Reset timestep counter (pass False to
                continue counting from a previous call)
            progress_bar: Display progress bar

        Returns:
            Trained algorithm instance (enables call chaining)
        """

    def predict(
        self,
        observation: Union[np.ndarray, Dict[str, np.ndarray]],
        state: Optional[Tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """
        Get action from observation.

        Args:
            observation: Input observation (array, or dict of arrays for
                dict observation spaces)
            state: Hidden state for recurrent policies (None otherwise)
            episode_start: Start of episode mask
            deterministic: Use deterministic actions instead of sampling

        Returns:
            Tuple of (action, next_state); next_state is None for
            non-recurrent policies
        """

    def save(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None:
        """Save model to file path (or writable binary stream)."""

    @classmethod
    def load(
        cls,
        path: Union[str, pathlib.Path, io.BufferedIOBase],
        env: Optional[GymEnv] = None,
        device: Union[torch.device, str] = "auto",
        custom_objects: Optional[Dict[str, Any]] = None,
        print_system_info: bool = False,
        force_reset: bool = True,
        **kwargs,
    ) -> "BaseAlgorithm":
        """Load model from file path (classmethod alternate constructor)."""

    def set_env(self, env: GymEnv, force_reset: bool = True) -> None:
        """Set new environment for the algorithm."""

    def get_env(self) -> Optional[VecEnv]:
        """Get current environment (as a VecEnv), if one is set."""

    def set_random_seed(self, seed: Optional[int] = None) -> None:
        """Set random seed for reproducibility."""

class OnPolicyAlgorithm(BaseAlgorithm):
    """
    Base class for on-policy algorithms (A2C, PPO).

    Extends BaseAlgorithm with rollout collection into a RolloutBuffer.

    Additional Args:
        n_steps: Number of steps per environment per update
        gamma: Discount factor
        gae_lambda: GAE lambda parameter (bias/variance trade-off)
        ent_coef: Entropy coefficient
        vf_coef: Value function coefficient
        max_grad_norm: Maximum gradient norm for clipping
        rollout_buffer_class: Rollout buffer class (default chosen internally
            when None)
        rollout_buffer_kwargs: Additional rollout buffer arguments
    """
    def __init__(
        self,
        policy: Union[str, Type[ActorCriticPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule],
        n_steps: int,
        gamma: float,
        gae_lambda: float,
        ent_coef: float,
        vf_coef: float,
        max_grad_norm: float,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
        rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ): ...

    def collect_rollouts(
        self,
        env: VecEnv,
        callback: BaseCallback,
        rollout_buffer: RolloutBuffer,
        n_rollout_steps: int,
    ) -> bool:
        """Collect rollout data from environment.

        Returns:
            bool flag; NOTE(review): presumably True when collection
            completed and False when aborted by the callback -- confirm
            against the implementation.
        """

class OffPolicyAlgorithm(BaseAlgorithm):
    """
    Base class for off-policy algorithms (SAC, TD3, DDPG, DQN).

    Extends BaseAlgorithm with experience replay.

    Additional Args:
        buffer_size: Replay buffer size
        learning_starts: Steps before learning starts
        batch_size: Minibatch size for training
        tau: Soft update coefficient for target networks
        gamma: Discount factor
        train_freq: Training frequency; an int or a (frequency, unit)
            tuple such as (1, "step")
        gradient_steps: Gradient steps per training
        action_noise: Action noise for exploration
        replay_buffer_class: Replay buffer class (default chosen internally
            when None)
        replay_buffer_kwargs: Additional buffer arguments
        optimize_memory_usage: Enable memory optimizations
    """
    def __init__(
        self,
        policy: Union[str, Type[BasePolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule],
        buffer_size: int = 1_000_000,
        learning_starts: int = 100,
        batch_size: int = 256,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = (1, "step"),
        gradient_steps: int = 1,
        action_noise: Optional[ActionNoise] = None,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        **kwargs,
    ): ...

    def _sample_action(
        self,
        learning_starts: int,
        action_noise: Optional[ActionNoise] = None,
        n_envs: int = 1,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Sample action with exploration noise (internal helper)."""

Policy Base Classes

Neural network architectures that define how observations are processed and actions are selected, supporting different observation spaces and algorithm requirements.

class BaseModel(torch.nn.Module):
    """
    Base class for all neural network models.

    Args:
        observation_space: Observation space
        action_space: Action space
        lr_schedule: Learning rate schedule
        use_sde: Whether to use State Dependent Exploration
        log_std_init: Initial log standard deviation
        full_std: Use full covariance matrix for SDE
        sde_net_arch: Network architecture for SDE
            (NOTE(review): deprecated in recent SB3 releases -- confirm
            against the installed version)
        use_expln: Use exponential activation for variance
        squash_output: Squash output with tanh
        features_extractor_class: Feature extractor class
        features_extractor_kwargs: Feature extractor arguments
        share_features_extractor: Share feature extractor between actor/critic
        normalize_images: Normalize image observations
        optimizer_class: Optimizer class
        optimizer_kwargs: Optimizer arguments
    """
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        sde_net_arch: Optional[List[int]] = None,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        share_features_extractor: bool = True,
        normalize_images: bool = True,
        optimizer_class: Type[torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ): ...

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Forward pass through the network."""

class BasePolicy(BaseModel):
    """
    Base policy class for all algorithms.

    Args:
        observation_space: Observation space
        action_space: Action space
        lr_schedule: Learning rate schedule
        use_sde: Whether to use State Dependent Exploration
        **kwargs: Additional arguments passed to BaseModel
    """
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        use_sde: bool = False,
        **kwargs,
    ): ...

    def forward(self, obs: torch.Tensor, deterministic: bool = False) -> torch.Tensor:
        """Get action from observation (tensor in, tensor out)."""

    def predict(
        self,
        observation: Union[np.ndarray, Dict[str, np.ndarray]],
        state: Optional[Tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """Get action and state from observation (NumPy convenience API,
        mirroring BaseAlgorithm.predict)."""

    def _predict(
        self, observation: torch.Tensor, deterministic: bool = False
    ) -> torch.Tensor:
        """Internal prediction method (tensor-level, used by predict)."""

class ActorCriticPolicy(BasePolicy):
    """
    Policy with both actor and critic networks for on-policy algorithms.

    Args:
        observation_space: Observation space
        action_space: Action space
        lr_schedule: Learning rate schedule
        net_arch: Network architecture specification
        activation_fn: Activation function
        ortho_init: Use orthogonal initialization
        use_sde: Whether to use State Dependent Exploration
        log_std_init: Initial log standard deviation
        full_std: Use full covariance matrix for SDE
        sde_net_arch: Network architecture for SDE
        use_expln: Use exponential activation for variance
        squash_output: Squash output with tanh
        features_extractor_class: Feature extractor class
        features_extractor_kwargs: Feature extractor arguments
        share_features_extractor: Share feature extractor between actor/critic
        normalize_images: Normalize image observations
        optimizer_class: Optimizer class
        optimizer_kwargs: Optimizer arguments
    """
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
        activation_fn: Type[torch.nn.Module] = torch.nn.Tanh,
        ortho_init: bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        sde_net_arch: Optional[List[int]] = None,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        share_features_extractor: bool = True,
        normalize_images: bool = True,
        optimizer_class: Type[torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ): ...

    def forward(self, obs: torch.Tensor, deterministic: bool = False) -> torch.Tensor:
        """Forward pass through actor network."""

    def evaluate_actions(
        self, obs: torch.Tensor, actions: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Evaluate actions for training.

        Returns a 3-tuple of tensors; NOTE(review): presumably
        (values, log_prob, entropy) with entropy possibly None --
        confirm against the implementation.
        """

    def get_distribution(self, obs: torch.Tensor) -> Distribution:
        """Get action distribution from observation."""

    def predict_values(self, obs: torch.Tensor) -> torch.Tensor:
        """Predict state values using critic network."""

Common Policy Aliases

All algorithms provide convenient aliases for their policy classes to simplify usage:

# Standard policy aliases (available in each algorithm module)
MlpPolicy = ActorCriticPolicy  # For on-policy algorithms like A2C, PPO
CnnPolicy = ActorCriticCnnPolicy  # For image-based observations
MultiInputPolicy = MultiInputActorCriticPolicy  # For dict observations

# Import examples -- each algorithm module exposes its own aliases.
# NOTE(review): off-policy algorithms (e.g. SAC) bind these alias names to
# their own policy classes rather than ActorCriticPolicy -- confirm per
# algorithm module.
from stable_baselines3.ppo import MlpPolicy as PPOMlpPolicy
from stable_baselines3.a2c import CnnPolicy as A2CCnnPolicy
from stable_baselines3.sac import MlpPolicy as SACMlpPolicy

Experience Buffers

Storage mechanisms for training data that enable different sampling strategies and memory management approaches for various algorithm types.

class BaseBuffer:
    """
    Abstract base class for all experience buffers.

    Args:
        buffer_size: Maximum buffer capacity
        observation_space: Observation space
        action_space: Action space
        device: PyTorch device placement
        n_envs: Number of parallel environments
    """
    def __init__(
        self,
        buffer_size: int,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        device: Union[torch.device, str] = "auto",
        n_envs: int = 1,
    ): ...

    def add(self, *args, **kwargs) -> None:
        """Add experience to buffer (signature defined by subclasses)."""

    def get(self, *args, **kwargs) -> Any:
        """Sample experience from buffer (signature defined by subclasses)."""

    def reset(self) -> None:
        """Reset buffer to empty state."""

    def size(self) -> int:
        """Current buffer size (number of stored elements)."""

class RolloutBuffer(BaseBuffer):
    """
    Buffer for on-policy algorithms that stores rollout trajectories.

    Args:
        buffer_size: Buffer capacity (typically n_steps * n_envs)
        observation_space: Observation space
        action_space: Action space
        device: PyTorch device placement
        gae_lambda: GAE lambda parameter
        gamma: Discount factor
        n_envs: Number of parallel environments
    """
    def __init__(
        self,
        buffer_size: int,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        device: Union[torch.device, str] = "auto",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ): ...

    def add(
        self,
        obs: np.ndarray,
        actions: np.ndarray,
        rewards: np.ndarray,
        episode_starts: np.ndarray,
        values: torch.Tensor,
        log_probs: torch.Tensor,
    ) -> None:
        """
        Add rollout data to buffer.

        Args:
            obs: Observations
            actions: Actions taken
            rewards: Rewards received
            episode_starts: Episode start flags
            values: State value estimates (from the critic)
            log_probs: Action log probabilities
        """

    def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSamples, None, None]:
        """
        Sample batches from buffer.

        Args:
            batch_size: Size of batches to sample (presumably the whole
                buffer is returned as one batch when None -- confirm)

        Yields:
            Batches of rollout data
        """

    def compute_returns_and_advantage(
        self, last_values: torch.Tensor, dones: np.ndarray
    ) -> None:
        """
        Compute returns and advantages using GAE.

        Call after a rollout is collected, before sampling with get().

        Args:
            last_values: Value estimates for final states
            dones: Episode termination flags
        """

class ReplayBuffer(BaseBuffer):
    """
    Experience replay buffer for off-policy algorithms.

    Args:
        buffer_size: Maximum buffer capacity
        observation_space: Observation space
        action_space: Action space
        device: PyTorch device placement
        n_envs: Number of parallel environments
        optimize_memory_usage: Enable memory optimizations
        handle_timeout_termination: Handle timeout terminations properly
            (distinguish time-limit truncation from true episode end)
    """
    def __init__(
        self,
        buffer_size: int,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        device: Union[torch.device, str] = "auto",
        n_envs: int = 1,
        optimize_memory_usage: bool = False,
        handle_timeout_termination: bool = True,
    ): ...

    def add(
        self,
        obs: np.ndarray,
        next_obs: np.ndarray,
        actions: np.ndarray,
        rewards: np.ndarray,
        dones: np.ndarray,
        infos: List[Dict[str, Any]],
    ) -> None:
        """
        Add transition to replay buffer.

        Args:
            obs: Current observations
            next_obs: Next observations
            actions: Actions taken
            rewards: Rewards received
            dones: Episode termination flags
            infos: Additional information (one dict per environment)
        """

    def sample(self, batch_size: int, env: Optional[VecEnv] = None) -> ReplayBufferSamples:
        """
        Sample batch of transitions.

        Args:
            batch_size: Number of transitions to sample
            env: Environment for normalization

        Returns:
            Batch of experience samples (ReplayBufferSamples)
        """

class DictRolloutBuffer(RolloutBuffer):
    """Rollout buffer variant for dictionary (gym.spaces.Dict) observations;
    same interface as RolloutBuffer."""

class DictReplayBuffer(ReplayBuffer):
    """Replay buffer variant for dictionary (gym.spaces.Dict) observations;
    same interface as ReplayBuffer."""

Usage Examples

Custom Policy Architecture

from stable_baselines3 import PPO
from stable_baselines3.common.policies import ActorCriticPolicy
import torch.nn as nn

# Define custom network architecture.
# NOTE: since SB3 v1.8 `net_arch` takes the dict directly; the older
# list-wrapped form `net_arch=[dict(pi=..., vf=...)]` is deprecated/removed.
policy_kwargs = dict(
    net_arch=dict(pi=[128, 128], vf=[128, 128]),  # separate actor/critic heads
    activation_fn=nn.ReLU,
    ortho_init=True,
)

model = PPO(
    "MlpPolicy",
    env,
    policy_kwargs=policy_kwargs,
    verbose=1
)

Custom Buffer Configuration

from stable_baselines3 import SAC
from stable_baselines3.common.buffers import ReplayBuffer

# Custom replay buffer settings; these kwargs are forwarded to the
# ReplayBuffer constructor via OffPolicyAlgorithm's replay_buffer_kwargs.
replay_buffer_kwargs = dict(
    optimize_memory_usage=True,
    handle_timeout_termination=True,
)

model = SAC(
    "MlpPolicy",
    env,
    buffer_size=500000,
    replay_buffer_kwargs=replay_buffer_kwargs,
    verbose=1
)

Accessing Buffer Data

# Access replay buffer after training
replay_buffer = model.replay_buffer

# Sample transitions for analysis; sample() returns a ReplayBufferSamples
# object with tensor fields (observations, actions, rewards, ...).
batch = replay_buffer.sample(batch_size=256)
observations = batch.observations
actions = batch.actions
rewards = batch.rewards

Types

from typing import Union, Optional, Type, Callable, Dict, Any, List, Tuple
import numpy as np
import torch
import gymnasium as gym
import pathlib
import io
from stable_baselines3.common.type_aliases import GymEnv, Schedule, MaybeCallback, PyTorchObs, TensorDict
from stable_baselines3.common.policies import BasePolicy, ActorCriticPolicy, ActorCriticCnnPolicy, MultiInputActorCriticPolicy
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.vec_env import VecEnv

Install with Tessl CLI

npx tessl i tessl/pypi-stable-baselines3

docs

algorithms.md

common-framework.md

environments.md

her.md

index.md

training-utilities.md

tile.json