PyTorch version of Stable Baselines, providing implementations of reinforcement learning algorithms.
Base classes, policies, and buffers that provide the foundation for all algorithms and enable consistent behavior across the Stable Baselines3 library. This framework promotes code reuse and ensures uniform interfaces across different algorithm implementations.
Abstract base classes that define the core functionality shared by all reinforcement learning algorithms, including training loops, model management, and prediction interfaces.
class BaseAlgorithm:
    """
    Abstract base class for all RL algorithms.

    Args:
        policy: Policy class or string identifier
        env: Environment or environment ID
        learning_rate: Learning rate for optimization
        policy_kwargs: Additional arguments for policy construction
        stats_window_size: Window size for rollout logging averaging
        tensorboard_log: Path to TensorBoard log directory
        verbose: Verbosity level (0: no output, 1: info, 2: debug)
        device: PyTorch device placement ("auto", "cpu", "cuda")
        support_multi_env: Whether algorithm supports multiple environments
        monitor_wrapper: Whether to wrap environment with Monitor
        seed: Random seed for reproducibility
        use_sde: Whether to use State Dependent Exploration
        sde_sample_freq: Sample frequency for SDE
        supported_action_spaces: List of supported action spaces
    """

    def __init__(
        self,
        policy: Union[str, Type[BasePolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule],
        policy_kwargs: Optional[Dict[str, Any]] = None,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        verbose: int = 0,
        device: Union[torch.device, str] = "auto",
        support_multi_env: bool = False,
        monitor_wrapper: bool = True,
        seed: Optional[int] = None,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        supported_action_spaces: Optional[Tuple[Type[gym.Space], ...]] = None,
    ): ...

    def learn(
        self,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        tb_log_name: str = "run",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> "BaseAlgorithm":
        """
        Train the agent for total_timesteps.

        Args:
            total_timesteps: Total number of timesteps to train
            callback: Callback(s) called during training
            log_interval: Log interval for training metrics
            tb_log_name: TensorBoard log name
            reset_num_timesteps: Reset timestep counter
            progress_bar: Display progress bar

        Returns:
            Trained algorithm instance
        """

    def predict(
        self,
        observation: Union[np.ndarray, Dict[str, np.ndarray]],
        state: Optional[Tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """
        Get action from observation.

        Args:
            observation: Input observation
            state: Hidden state for recurrent policies
            episode_start: Start of episode mask
            deterministic: Use deterministic actions

        Returns:
            Tuple of (action, next_state)
        """

    def save(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None:
        """Save model to file path."""

    @classmethod
    def load(
        cls,
        path: Union[str, pathlib.Path, io.BufferedIOBase],
        env: Optional[GymEnv] = None,
        device: Union[torch.device, str] = "auto",
        custom_objects: Optional[Dict[str, Any]] = None,
        print_system_info: bool = False,
        force_reset: bool = True,
        **kwargs,
    ) -> "BaseAlgorithm":
        """Load model from file path."""

    def set_env(self, env: GymEnv, force_reset: bool = True) -> None:
        """Set new environment for the algorithm."""

    def get_env(self) -> Optional[VecEnv]:
        """Get current environment."""

    def set_random_seed(self, seed: Optional[int] = None) -> None:
        """Set random seed for reproducibility."""
class OnPolicyAlgorithm(BaseAlgorithm):
    """
    Base class for on-policy algorithms (A2C, PPO).

    Additional Args:
        n_steps: Number of steps per environment per update
        gamma: Discount factor
        gae_lambda: GAE lambda parameter
        ent_coef: Entropy coefficient
        vf_coef: Value function coefficient
        max_grad_norm: Maximum gradient norm for clipping
    """

    def __init__(
        self,
        policy: Union[str, Type[ActorCriticPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule],
        n_steps: int,
        gamma: float,
        gae_lambda: float,
        ent_coef: float,
        vf_coef: float,
        max_grad_norm: float,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
        rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ): ...

    def collect_rollouts(
        self,
        env: VecEnv,
        callback: BaseCallback,
        rollout_buffer: RolloutBuffer,
        n_rollout_steps: int,
    ) -> bool:
        """Collect rollout data from environment."""
class OffPolicyAlgorithm(BaseAlgorithm):
    """
    Base class for off-policy algorithms (SAC, TD3, DDPG, DQN).

    Additional Args:
        buffer_size: Replay buffer size
        learning_starts: Steps before learning starts
        batch_size: Minibatch size for training
        tau: Soft update coefficient for target networks
        train_freq: Training frequency
        gradient_steps: Gradient steps per training
        action_noise: Action noise for exploration
        replay_buffer_class: Replay buffer class
        replay_buffer_kwargs: Additional buffer arguments
        optimize_memory_usage: Enable memory optimizations
    """

    def __init__(
        self,
        policy: Union[str, Type[BasePolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule],
        buffer_size: int = 1_000_000,
        learning_starts: int = 100,
        batch_size: int = 256,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = (1, "step"),
        gradient_steps: int = 1,
        action_noise: Optional[ActionNoise] = None,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        **kwargs,
    ): ...

    def _sample_action(
        self,
        learning_starts: int,
        action_noise: Optional[ActionNoise] = None,
        n_envs: int = 1,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Sample action with exploration noise."""


# Neural network architectures that define how observations are processed and
# actions are selected, supporting different observation spaces and algorithm
# requirements.
class BaseModel(torch.nn.Module):
    """
    Base class for all neural network models.

    Args:
        observation_space: Observation space
        action_space: Action space
        lr_schedule: Learning rate schedule
        use_sde: Whether to use State Dependent Exploration
        log_std_init: Initial log standard deviation
        full_std: Use full covariance matrix for SDE
        sde_net_arch: Network architecture for SDE
        use_expln: Use exponential activation for variance
        squash_output: Squash output with tanh
        features_extractor_class: Feature extractor class
        features_extractor_kwargs: Feature extractor arguments
        share_features_extractor: Share feature extractor between actor/critic
        normalize_images: Normalize image observations
        optimizer_class: Optimizer class
        optimizer_kwargs: Optimizer arguments
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        sde_net_arch: Optional[List[int]] = None,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        share_features_extractor: bool = True,
        normalize_images: bool = True,
        optimizer_class: Type[torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ): ...

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Forward pass through the network."""
class BasePolicy(BaseModel):
    """
    Base policy class for all algorithms.

    Args:
        observation_space: Observation space
        action_space: Action space
        lr_schedule: Learning rate schedule
        use_sde: Whether to use State Dependent Exploration
        **kwargs: Additional arguments passed to BaseModel
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        use_sde: bool = False,
        **kwargs,
    ): ...

    def forward(self, obs: torch.Tensor, deterministic: bool = False) -> torch.Tensor:
        """Get action from observation."""

    def predict(
        self,
        observation: Union[np.ndarray, Dict[str, np.ndarray]],
        state: Optional[Tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """Get action and state from observation."""

    def _predict(
        self, observation: torch.Tensor, deterministic: bool = False
    ) -> torch.Tensor:
        """Internal prediction method."""
class ActorCriticPolicy(BasePolicy):
    """
    Policy with both actor and critic networks for on-policy algorithms.

    Args:
        observation_space: Observation space
        action_space: Action space
        lr_schedule: Learning rate schedule
        net_arch: Network architecture specification
        activation_fn: Activation function
        ortho_init: Use orthogonal initialization
        use_sde: Whether to use State Dependent Exploration
        log_std_init: Initial log standard deviation
        full_std: Use full covariance matrix for SDE
        sde_net_arch: Network architecture for SDE
        use_expln: Use exponential activation for variance
        squash_output: Squash output with tanh
        features_extractor_class: Feature extractor class
        features_extractor_kwargs: Feature extractor arguments
        share_features_extractor: Share feature extractor between actor/critic
        normalize_images: Normalize image observations
        optimizer_class: Optimizer class
        optimizer_kwargs: Optimizer arguments
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
        activation_fn: Type[torch.nn.Module] = torch.nn.Tanh,
        ortho_init: bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        sde_net_arch: Optional[List[int]] = None,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        share_features_extractor: bool = True,
        normalize_images: bool = True,
        optimizer_class: Type[torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ): ...

    def forward(self, obs: torch.Tensor, deterministic: bool = False) -> torch.Tensor:
        """Forward pass through actor network."""

    def evaluate_actions(
        self, obs: torch.Tensor, actions: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Evaluate actions for training."""

    def get_distribution(self, obs: torch.Tensor) -> Distribution:
        """Get action distribution from observation."""

    def predict_values(self, obs: torch.Tensor) -> torch.Tensor:
        """Predict state values using critic network."""


# All algorithms provide convenient aliases for their policy classes to
# simplify usage:
# Standard policy aliases (available in each algorithm module)
MlpPolicy = ActorCriticPolicy  # For algorithms like A2C, PPO
CnnPolicy = ActorCriticCnnPolicy  # For image-based observations
MultiInputPolicy = MultiInputActorCriticPolicy  # For dict observations

# Import examples:
from stable_baselines3.ppo import MlpPolicy as PPOMlpPolicy
from stable_baselines3.a2c import CnnPolicy as A2CCnnPolicy
from stable_baselines3.sac import MlpPolicy as SACMlpPolicy

# Storage mechanisms for training data that enable different sampling
# strategies and memory management approaches for various algorithm types.
class BaseBuffer:
    """
    Abstract base class for all experience buffers.

    Args:
        buffer_size: Maximum buffer capacity
        observation_space: Observation space
        action_space: Action space
        device: PyTorch device placement
        n_envs: Number of parallel environments
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        device: Union[torch.device, str] = "auto",
        n_envs: int = 1,
    ): ...

    def add(self, *args, **kwargs) -> None:
        """Add experience to buffer."""

    def get(self, *args, **kwargs) -> Any:
        """Sample experience from buffer."""

    def reset(self) -> None:
        """Reset buffer to empty state."""

    def size(self) -> int:
        """Current buffer size."""
class RolloutBuffer(BaseBuffer):
    """
    Buffer for on-policy algorithms that stores rollout trajectories.

    Args:
        buffer_size: Buffer capacity (typically n_steps * n_envs)
        observation_space: Observation space
        action_space: Action space
        device: PyTorch device placement
        gae_lambda: GAE lambda parameter
        gamma: Discount factor
        n_envs: Number of parallel environments
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        device: Union[torch.device, str] = "auto",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ): ...

    def add(
        self,
        obs: np.ndarray,
        actions: np.ndarray,
        rewards: np.ndarray,
        episode_starts: np.ndarray,
        values: torch.Tensor,
        log_probs: torch.Tensor,
    ) -> None:
        """
        Add rollout data to buffer.

        Args:
            obs: Observations
            actions: Actions taken
            rewards: Rewards received
            episode_starts: Episode start flags
            values: State value estimates
            log_probs: Action log probabilities
        """

    def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSamples, None, None]:
        """
        Sample batches from buffer.

        Args:
            batch_size: Size of batches to sample

        Yields:
            Batches of rollout data
        """

    def compute_returns_and_advantage(
        self, last_values: torch.Tensor, dones: np.ndarray
    ) -> None:
        """
        Compute returns and advantages using GAE.

        Args:
            last_values: Value estimates for final states
            dones: Episode termination flags
        """
class ReplayBuffer(BaseBuffer):
    """
    Experience replay buffer for off-policy algorithms.

    Args:
        buffer_size: Maximum buffer capacity
        observation_space: Observation space
        action_space: Action space
        device: PyTorch device placement
        n_envs: Number of parallel environments
        optimize_memory_usage: Enable memory optimizations
        handle_timeout_termination: Handle timeout terminations properly
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        device: Union[torch.device, str] = "auto",
        n_envs: int = 1,
        optimize_memory_usage: bool = False,
        handle_timeout_termination: bool = True,
    ): ...

    def add(
        self,
        obs: np.ndarray,
        next_obs: np.ndarray,
        actions: np.ndarray,
        rewards: np.ndarray,
        dones: np.ndarray,
        infos: List[Dict[str, Any]],
    ) -> None:
        """
        Add transition to replay buffer.

        Args:
            obs: Current observations
            next_obs: Next observations
            actions: Actions taken
            rewards: Rewards received
            dones: Episode termination flags
            infos: Additional information
        """

    def sample(self, batch_size: int, env: Optional[VecEnv] = None) -> ReplayBufferSamples:
        """
        Sample batch of transitions.

        Args:
            batch_size: Number of transitions to sample
            env: Environment for normalization

        Returns:
            Batch of experience samples
        """
class DictRolloutBuffer(RolloutBuffer):
    """Rollout buffer for dictionary observations."""
class DictReplayBuffer(ReplayBuffer):
    """Replay buffer for dictionary observations."""


from stable_baselines3 import PPO
from stable_baselines3.common.policies import ActorCriticPolicy
import torch.nn as nn

# Define custom network architecture
policy_kwargs = dict(
    net_arch=[dict(pi=[128, 128], vf=[128, 128])],
    activation_fn=nn.ReLU,
    ortho_init=True,
)
model = PPO(
    "MlpPolicy",
    env,
    policy_kwargs=policy_kwargs,
    verbose=1
)

from stable_baselines3 import SAC
from stable_baselines3.common.buffers import ReplayBuffer

# Custom replay buffer settings
replay_buffer_kwargs = dict(
    optimize_memory_usage=True,
    handle_timeout_termination=True,
)
model = SAC(
    "MlpPolicy",
    env,
    buffer_size=500000,
    replay_buffer_kwargs=replay_buffer_kwargs,
    verbose=1
)

# Access replay buffer after training
replay_buffer = model.replay_buffer
# Sample transitions for analysis
batch = replay_buffer.sample(batch_size=256)
observations = batch.observations
actions = batch.actions
rewards = batch.rewards

from typing import Union, Optional, Type, Callable, Dict, Any, List, Tuple
import numpy as np
import torch
import gymnasium as gym
import pathlib
import io
from stable_baselines3.common.type_aliases import GymEnv, Schedule, MaybeCallback, PyTorchObs, TensorDict
from stable_baselines3.common.policies import BasePolicy, ActorCriticPolicy, ActorCriticCnnPolicy, MultiInputActorCriticPolicy
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.vec_env import VecEnv

# Install with Tessl CLI:
#   npx tessl i tessl/pypi-stable-baselines3