PyTorch version of Stable Baselines: implementations of reinforcement learning algorithms.
Environment vectorization and wrappers for parallel training, normalization, monitoring, and other common preprocessing tasks. These components enable efficient training across multiple environment instances and provide essential functionality for production RL systems.
Foundation classes for creating vectorized environments that enable parallel execution and consistent interfaces across different parallelization strategies.
class VecEnv:
    """
    Abstract base class for vectorized environments.

    Args:
        num_envs: Number of environments
        observation_space: Single environment observation space
        action_space: Single environment action space
    """

    def __init__(
        self,
        num_envs: int,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
    ): ...

    def reset(self) -> VecEnvObs:
        """
        Reset all environments.

        Returns:
            Observations from all environments
        """

    def step_async(self, actions: np.ndarray) -> None:
        """
        Tell environments to start stepping with given actions.

        Args:
            actions: Actions for each environment
        """

    def step_wait(self) -> VecEnvStepReturn:
        """
        Wait for environments to finish stepping.

        Returns:
            Tuple of (observations, rewards, dones, infos)
        """

    def step(self, actions: np.ndarray) -> VecEnvStepReturn:
        """
        Step all environments synchronously.

        Args:
            actions: Actions for each environment

        Returns:
            Tuple of (observations, rewards, dones, infos)
        """

    def close(self) -> None:
        """Close all environments."""

    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
        """
        Get attribute from environments.

        Args:
            attr_name: Name of attribute to get
            indices: Environment indices (None for all)

        Returns:
            List of attribute values
        """

    def set_attr(
        self, attr_name: str, value: Any, indices: VecEnvIndices = None
    ) -> None:
        """
        Set attribute in environments.

        Args:
            attr_name: Name of attribute to set
            value: Value to set
            indices: Environment indices (None for all)
        """

    def env_method(
        self,
        method_name: str,
        *method_args,
        indices: VecEnvIndices = None,
        **method_kwargs,
    ) -> List[Any]:
        """
        Call method on environments.

        Args:
            method_name: Name of method to call
            *method_args: Positional arguments for method
            indices: Environment indices (None for all)
            **method_kwargs: Keyword arguments for method

        Returns:
            List of method return values
        """

    def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
        """
        Set random seed for environments.

        Args:
            seed: Random seed

        Returns:
            List of seeds used by each environment
        """

    def render(self, mode: str = "human") -> Optional[np.ndarray]:
        """
        Render environments.

        Args:
            mode: Rendering mode

        Returns:
            Rendered images if mode is 'rgb_array'
        """

Simple vectorized environment that runs environments sequentially in the same process, suitable for lightweight environments and debugging.
class DummyVecEnv(VecEnv):
    """
    Sequential vectorized environment.

    Args:
        env_fns: List of functions that create environments
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]): ...

    def reset(self) -> VecEnvObs:
        """Reset all environments sequentially."""

    def step_async(self, actions: np.ndarray) -> None:
        """Store actions for stepping."""

    def step_wait(self) -> VecEnvStepReturn:
        """Step all environments sequentially."""

    def close(self) -> None:
        """Close all environments."""

    def render(self, mode: str = "human") -> Optional[np.ndarray]:
        """Render first environment."""

Vectorized environment that runs environments in separate processes for true parallelization, ideal for computationally expensive environments.
class SubprocVecEnv(VecEnv):
    """
    Multiprocessing vectorized environment.

    Args:
        env_fns: List of functions that create environments
        start_method: Multiprocessing start method ('spawn', 'fork', 'forkserver')
    """

    def __init__(
        self,
        env_fns: List[Callable[[], gym.Env]],
        start_method: Optional[str] = None,
    ): ...

    def reset(self) -> VecEnvObs:
        """Reset all environments in parallel."""

    def step_async(self, actions: np.ndarray) -> None:
        """Send actions to worker processes."""

    def step_wait(self) -> VecEnvStepReturn:
        """Collect results from worker processes."""

    def close(self) -> None:
        """Close all worker processes."""

    def render(self, mode: str = "human") -> Optional[np.ndarray]:
        """Render first environment."""

Base class and common wrappers for adding functionality to vectorized environments while maintaining the vectorized interface.
class VecEnvWrapper(VecEnv):
    """
    Base class for vectorized environment wrappers.

    Args:
        venv: Vectorized environment to wrap
    """

    def __init__(self, venv: VecEnv): ...

    def reset(self) -> VecEnvObs:
        """Reset wrapped environment."""

    def step_async(self, actions: np.ndarray) -> None:
        """Forward step_async to wrapped environment."""

    def step_wait(self) -> VecEnvStepReturn:
        """Forward step_wait to wrapped environment."""

    def close(self) -> None:
        """Close wrapped environment."""
class VecNormalize(VecEnvWrapper):
    """
    Normalize observations and rewards using running statistics.

    Args:
        venv: Vectorized environment to wrap
        training: Whether in training mode (updates statistics)
        norm_obs: Whether to normalize observations
        norm_reward: Whether to normalize rewards
        clip_obs: Observation clipping range
        clip_reward: Reward clipping range
        gamma: Discount factor for reward normalization
        epsilon: Small constant for numerical stability
        norm_obs_keys: Observation keys to normalize (for dict obs)
    """

    def __init__(
        self,
        venv: VecEnv,
        training: bool = True,
        norm_obs: bool = True,
        norm_reward: bool = True,
        clip_obs: float = 10.0,
        clip_reward: float = 10.0,
        gamma: float = 0.99,
        epsilon: float = 1e-8,
        norm_obs_keys: Optional[List[str]] = None,
    ): ...

    def normalize_obs(self, obs: VecEnvObs) -> VecEnvObs:
        """
        Normalize observations using running statistics.

        Args:
            obs: Observations to normalize

        Returns:
            Normalized observations
        """

    def normalize_reward(self, reward: np.ndarray) -> np.ndarray:
        """
        Normalize rewards using running statistics.

        Args:
            reward: Rewards to normalize

        Returns:
            Normalized rewards
        """

    def get_original_obs(self) -> Optional[VecEnvObs]:
        """Get unnormalized observations."""

    def get_original_reward(self) -> Optional[np.ndarray]:
        """Get unnormalized rewards."""

    def reset(self) -> VecEnvObs:
        """Reset and normalize observations."""

    def step_wait(self) -> VecEnvStepReturn:
        """Step and normalize observations/rewards."""
class VecFrameStack(VecEnvWrapper):
    """
    Stack frames for recurrent policies or temporal information.

    Args:
        venv: Vectorized environment to wrap
        n_stack: Number of frames to stack
        channels_order: Channel order ('last' or 'first')
    """

    def __init__(
        self,
        venv: VecEnv,
        n_stack: int,
        channels_order: str = "last",
    ): ...

    def reset(self) -> VecEnvObs:
        """Reset and initialize frame stack."""

    def step_wait(self) -> VecEnvStepReturn:
        """Step and update frame stack."""
class VecTransposeImage(VecEnvWrapper):
    """
    Transpose image observations from (H, W, C) to (C, H, W).

    Args:
        venv: Vectorized environment to wrap
        skip: Skip transposition (for debugging)
    """

    def __init__(self, venv: VecEnv, skip: bool = False): ...
class VecMonitor(VecEnvWrapper):
    """
    Monitor wrapper for vectorized environments.

    Args:
        venv: Vectorized environment to wrap
        filename: Path to log file (None for no logging)
        info_keywords: Info dict keys to log
    """

    def __init__(
        self,
        venv: VecEnv,
        filename: Optional[str] = None,
        info_keywords: Tuple[str, ...] = (),
    ): ...
class VecCheckNan(VecEnvWrapper):
    """
    Check for NaN values in observations, rewards, and actions.

    Args:
        venv: Vectorized environment to wrap
        raise_exception: Whether to raise exception on NaN detection
        warn_once: Whether to warn only once per NaN type
    """

    def __init__(
        self,
        venv: VecEnv,
        raise_exception: bool = False,
        warn_once: bool = True,
    ): ...
class VecExtractDictObs(VecEnvWrapper):
    """
    Extract specific key from dictionary observations.

    Args:
        venv: Vectorized environment to wrap
        key: Dictionary key to extract
    """

    def __init__(self, venv: VecEnv, key: str): ...
class VecVideoRecorder(VecEnvWrapper):
    """
    Record videos from vectorized environments.

    Args:
        venv: Vectorized environment to wrap
        video_folder: Directory to save videos
        record_video_trigger: Function determining when to record
        video_length: Length of recorded videos
        name_prefix: Prefix for video filenames
    """

    def __init__(
        self,
        venv: VecEnv,
        video_folder: str,
        record_video_trigger: Callable[[int], bool],
        video_length: int = 200,
        name_prefix: str = "rl-video",
    ): ...

Additional utilities for environment management, monitoring, and validation that complement the vectorized environment system.
class Monitor(gym.Wrapper):
    """
    Environment wrapper for logging episode statistics.

    Args:
        env: Environment to wrap
        filename: Path to log file (None for no logging)
        allow_early_resets: Allow resetting before episode completion
        reset_keywords: Keywords to log from reset info
        info_keywords: Keywords to log from step info
        override_existing: Whether to override existing log file
    """

    def __init__(
        self,
        env: gym.Env,
        filename: Optional[str] = None,
        allow_early_resets: bool = True,
        reset_keywords: Tuple[str, ...] = (),
        info_keywords: Tuple[str, ...] = (),
        override_existing: bool = True,
    ): ...

    def reset(self, **kwargs) -> Tuple[np.ndarray, Dict[str, Any]]:
        """Reset environment and log episode statistics."""

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]:
        """Step environment and log statistics."""
def make_vec_env(
    env_id: Union[str, Callable[[], gym.Env]],
    n_envs: int = 1,
    seed: Optional[int] = None,
    start_index: int = 0,
    monitor_dir: Optional[str] = None,
    wrapper_class: Optional[Callable[[gym.Env], gym.Env]] = None,
    env_kwargs: Optional[Dict[str, Any]] = None,
    vec_env_cls: Type[VecEnv] = DummyVecEnv,
    vec_env_kwargs: Optional[Dict[str, Any]] = None,
    monitor_kwargs: Optional[Dict[str, Any]] = None,
    wrapper_kwargs: Optional[Dict[str, Any]] = None,
) -> VecEnv:
    """
    Create vectorized environment with optional monitoring and wrappers.

    Args:
        env_id: Environment ID or environment creation function
        n_envs: Number of environments
        seed: Random seed for environments
        start_index: Starting index for environment seeds
        monitor_dir: Directory for Monitor logs
        wrapper_class: Environment wrapper class
        env_kwargs: Arguments for environment creation
        vec_env_cls: Vectorized environment class
        vec_env_kwargs: Arguments for vectorized environment
        monitor_kwargs: Arguments for Monitor wrapper
        wrapper_kwargs: Arguments for environment wrapper

    Returns:
        Vectorized environment
    """
def check_env(
    env: gym.Env,
    warn: bool = True,
    skip_render_check: bool = True,
) -> None:
    """
    Check environment compliance with Gym interface.

    Args:
        env: Environment to check
        warn: Whether to show warnings
        skip_render_check: Skip render method checking
    """

import gymnasium as gym
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
# Sequential vectorization (single process)
env_fns = [lambda: gym.make("CartPole-v1") for _ in range(4)]
vec_env = DummyVecEnv(env_fns)
# Parallel vectorization (multiprocessing)
vec_env = SubprocVecEnv(env_fns)
# Use with algorithm
from stable_baselines3 import PPO
model = PPO("MlpPolicy", vec_env, verbose=1)

from stable_baselines3.common.vec_env import VecNormalize
# Create and wrap environment
vec_env = DummyVecEnv([lambda: gym.make("Pendulum-v1") for _ in range(4)])
vec_env = VecNormalize(
vec_env,
norm_obs=True,
norm_reward=True,
clip_obs=10.0,
clip_reward=10.0,
)
# Train with normalization
model = PPO("MlpPolicy", vec_env, verbose=1)
model.learn(total_timesteps=10000)
# Save normalization statistics
vec_env.save("vecnormalize.pkl")
# Load for evaluation
vec_env = VecNormalize.load("vecnormalize.pkl", vec_env)
vec_env.training = False  # Disable updates during evaluation

from stable_baselines3.common.vec_env import VecFrameStack, VecTransposeImage
# Create Atari environment with frame stacking
env_fns = [lambda: gym.make("BreakoutNoFrameskip-v4") for _ in range(4)]
vec_env = DummyVecEnv(env_fns)
# Transpose images for CNN (H,W,C) -> (C,H,W)
vec_env = VecTransposeImage(vec_env)
# Stack 4 frames
vec_env = VecFrameStack(vec_env, n_stack=4)
model = PPO("CnnPolicy", vec_env, verbose=1)

from stable_baselines3.common.vec_env import VecMonitor
from stable_baselines3.common.monitor import Monitor
# Single environment monitoring
env = Monitor(gym.make("CartPole-v1"), "training.log")
# Vectorized environment monitoring
vec_env = DummyVecEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
vec_env = VecMonitor(vec_env, "vec_training.log")
# Load monitoring results
from stable_baselines3.common.monitor import load_results
import pandas as pd
results = load_results("training.log")
print(f"Mean reward: {results['r'].mean():.2f}")

from stable_baselines3.common.vec_env import make_vec_env
# Create multiple environments with monitoring
vec_env = make_vec_env(
"CartPole-v1",
n_envs=4,
seed=42,
monitor_dir="logs/",
vec_env_cls=SubprocVecEnv,
)
# Custom environment function
def make_custom_env():
env = gym.make("CartPole-v1")
# Add custom preprocessing here
return env
vec_env = make_vec_env(
make_custom_env,
n_envs=4,
vec_env_cls=DummyVecEnv,
)

Environment utility functions for wrapper management and synchronization:
def unwrap_vec_wrapper(env: VecEnv, vec_wrapper_class: Type[VecEnvWrapper]) -> Optional[VecEnvWrapper]:
    """
    Retrieve a VecEnvWrapper object by recursively searching.

    Args:
        env: The VecEnv that is going to be unwrapped
        vec_wrapper_class: The desired VecEnvWrapper class

    Returns:
        The VecEnvWrapper object if found, None otherwise
    """
def unwrap_vec_normalize(env: VecEnv) -> Optional[VecNormalize]:
    """
    Retrieve a VecNormalize object by recursively searching.

    Args:
        env: The VecEnv that is going to be unwrapped

    Returns:
        The VecNormalize object if found, None otherwise
    """
def is_vecenv_wrapped(env: VecEnv, vec_wrapper_class: Type[VecEnvWrapper]) -> bool:
    """
    Check if an environment is already wrapped in a given VecEnvWrapper.

    Args:
        env: The VecEnv that is going to be checked
        vec_wrapper_class: The desired VecEnvWrapper class

    Returns:
        True if wrapped with the desired wrapper, False otherwise
    """
def sync_envs_normalization(env: VecEnv, eval_env: VecEnv) -> None:
    """
    Synchronize the normalization statistics of train and eval environments
    when both are wrapped in VecNormalize.

    Args:
        env: Training environment
        eval_env: Environment used for evaluation
    """

from typing import Union, Optional, Type, Callable, Dict, Any, List, Tuple, Sequence
import numpy as np
import gymnasium as gym
from stable_baselines3.common.vec_env import VecEnv, VecEnvWrapper, DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.type_aliases import GymEnv

Install with Tessl CLI
npx tessl i tessl/pypi-stable-baselines3