CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-stable-baselines3

Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.

Overview
Eval results
Files

docs/environments.md

Vectorized Environments

Environment vectorization and wrappers for parallel training, normalization, monitoring, and other common preprocessing tasks. These components enable efficient training across multiple environment instances and provide essential functionality for production RL systems.

Capabilities

Vectorized Environment Base Classes

Foundation classes for creating vectorized environments that enable parallel execution and consistent interfaces across different parallelization strategies.

class VecEnv:
    """
    Abstract interface that every vectorized environment implements.

    A vectorized environment owns ``num_envs`` copies of an environment and
    exposes them through a single batched API.

    Args:
        num_envs: Number of environment copies managed
        observation_space: Observation space of one copy
        action_space: Action space of one copy
    """

    def __init__(
        self,
        num_envs: int,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
    ): ...

    def reset(self) -> VecEnvObs:
        """
        Reset every environment copy.

        Returns:
            The batch of observations produced by all copies
        """
        ...

    def step_async(self, actions: np.ndarray) -> None:
        """
        Begin stepping the copies; pair with ``step_wait`` to collect results.

        Args:
            actions: One action per environment copy
        """
        ...

    def step_wait(self) -> VecEnvStepReturn:
        """
        Block until the step started by ``step_async`` has finished.

        Returns:
            An (observations, rewards, dones, infos) tuple
        """
        ...

    def step(self, actions: np.ndarray) -> VecEnvStepReturn:
        """
        Advance all environment copies by one step and wait for the outcome.

        Args:
            actions: One action per environment copy

        Returns:
            An (observations, rewards, dones, infos) tuple
        """
        ...

    def close(self) -> None:
        """Shut down every environment copy."""
        ...

    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
        """
        Read an attribute from the environment copies.

        Args:
            attr_name: Attribute to read
            indices: Which copies to query; None selects all of them

        Returns:
            The attribute value of each queried copy
        """
        ...

    def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
        """
        Assign an attribute on the environment copies.

        Args:
            attr_name: Attribute to assign
            value: New value for the attribute
            indices: Which copies to update; None selects all of them
        """
        ...

    def env_method(
        self,
        method_name: str,
        *method_args,
        indices: VecEnvIndices = None,
        **method_kwargs,
    ) -> List[Any]:
        """
        Invoke a method on the environment copies.

        Args:
            method_name: Method to invoke
            *method_args: Positional arguments forwarded to the method
            indices: Which copies to call; None selects all of them
            **method_kwargs: Keyword arguments forwarded to the method

        Returns:
            The value returned by each call
        """
        ...

    def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
        """
        Seed the random number generators of the environment copies.

        Args:
            seed: Seed value (may be None)

        Returns:
            The seed each copy ended up using
        """
        ...

    def render(self, mode: str = "human") -> Optional[np.ndarray]:
        """
        Display or capture the environments.

        Args:
            mode: Render mode

        Returns:
            The rendered images when ``mode`` is 'rgb_array'
        """
        ...

Sequential Vectorized Environment

Simple vectorized environment that runs environments sequentially in the same process, suitable for lightweight environments and debugging.

class DummyVecEnv(VecEnv):
    """
    Vectorized environment that runs its environments one after another.

    All environments live in the current Python process and are stepped
    sequentially. This keeps overhead low for cheap environments and makes
    debugging straightforward.

    Args:
        env_fns: List of zero-argument callables, each returning a new environment
    """
    def __init__(self, env_fns: List[Callable[[], gym.Env]]): ...

    def reset(self) -> VecEnvObs:
        """Reset all environments sequentially."""

    def step_async(self, actions: np.ndarray) -> None:
        """Store actions for stepping."""

    def step_wait(self) -> VecEnvStepReturn:
        """Step all environments sequentially."""

    def close(self) -> None:
        """Close all environments."""

    def render(self, mode: str = "human") -> Optional[np.ndarray]:
        """
        Render the environments (not only the first one).

        When there are several environments, their frames are tiled together
        into a single image. In 'rgb_array' mode that image is returned.

        Args:
            mode: Rendering mode

        Returns:
            The (tiled) image in 'rgb_array' mode, otherwise None
        """

Multiprocessing Vectorized Environment

Vectorized environment that runs environments in separate processes for true parallelization, ideal for computationally expensive environments.

class SubprocVecEnv(VecEnv):
    """
    Vectorized environment that runs each environment in its own worker process.

    Steps execute in parallel across processes, which pays off when a single
    environment step is computationally expensive.

    Workers may be started with 'spawn' or 'forkserver', and both of these
    import the main module again in each child. For that reason, create this
    object inside an ``if __name__ == "__main__":`` block.

    Args:
        env_fns: List of zero-argument callables, each returning a new environment
        start_method: Multiprocessing start method ('spawn', 'fork', 'forkserver').
            None selects 'forkserver' when the platform supports it and
            'spawn' otherwise.
    """
    def __init__(
        self,
        env_fns: List[Callable[[], gym.Env]],
        start_method: Optional[str] = None,
    ): ...

    def reset(self) -> VecEnvObs:
        """Reset all environments in parallel."""

    def step_async(self, actions: np.ndarray) -> None:
        """Send actions to worker processes."""

    def step_wait(self) -> VecEnvStepReturn:
        """Collect results from worker processes."""

    def close(self) -> None:
        """Close all worker processes."""

    def render(self, mode: str = "human") -> Optional[np.ndarray]:
        """
        Render the environments (not only the first one).

        The frames of all workers are tiled together into a single image. In
        'rgb_array' mode that image is returned.

        Args:
            mode: Rendering mode

        Returns:
            The tiled image in 'rgb_array' mode, otherwise None
        """

Vectorized Environment Wrappers

Base class and common wrappers for adding functionality to vectorized environments while maintaining the vectorized interface.

class VecEnvWrapper(VecEnv):
    """
    Base for wrappers that add behaviour on top of another vectorized environment.

    A wrapper is itself a VecEnv, so wrappers can be stacked.

    Args:
        venv: The vectorized environment being wrapped
    """

    def __init__(self, venv: VecEnv):
        ...

    def reset(self) -> VecEnvObs:
        """Delegate the reset to the inner environment."""
        ...

    def step_async(self, actions: np.ndarray) -> None:
        """Hand the actions to the inner environment's ``step_async``."""
        ...

    def step_wait(self) -> VecEnvStepReturn:
        """Return the result of the inner environment's ``step_wait``."""
        ...

    def close(self) -> None:
        """Shut down the inner environment."""
        ...

class VecNormalize(VecEnvWrapper):
    """
    Normalize observations and rewards using running statistics.

    For evaluation, restore saved statistics with ``VecNormalize.load`` and
    then set ``training = False`` (freeze the statistics). Also set
    ``norm_reward = False`` so that the true rewards are reported.

    Args:
        venv: Vectorized environment to wrap
        training: Whether in training mode (updates statistics)
        norm_obs: Whether to normalize observations
        norm_reward: Whether to normalize rewards
        clip_obs: Normalized observations are clipped to [-clip_obs, clip_obs]
        clip_reward: Normalized rewards are clipped to [-clip_reward, clip_reward]
        gamma: Discount factor for reward normalization
        epsilon: Small constant for numerical stability
        norm_obs_keys: Observation keys to normalize (for dict obs)
    """
    def __init__(
        self,
        venv: VecEnv,
        training: bool = True,
        norm_obs: bool = True,
        norm_reward: bool = True,
        clip_obs: float = 10.0,
        clip_reward: float = 10.0,
        gamma: float = 0.99,
        epsilon: float = 1e-8,
        norm_obs_keys: Optional[List[str]] = None,
    ): ...

    def normalize_obs(self, obs: VecEnvObs) -> VecEnvObs:
        """
        Normalize observations using running statistics.

        Args:
            obs: Observations to normalize

        Returns:
            Normalized observations
        """

    def normalize_reward(self, reward: np.ndarray) -> np.ndarray:
        """
        Normalize rewards using running statistics.

        Args:
            reward: Rewards to normalize

        Returns:
            Normalized rewards
        """

    def get_original_obs(self) -> Optional[VecEnvObs]:
        """Get unnormalized observations."""

    def get_original_reward(self) -> Optional[np.ndarray]:
        """Get unnormalized rewards."""

    def reset(self) -> VecEnvObs:
        """Reset and normalize observations."""

    def step_wait(self) -> VecEnvStepReturn:
        """Step and normalize observations/rewards."""

    def save(self, save_path: str) -> None:
        """
        Pickle this wrapper, including its running statistics and settings.

        The wrapped environment itself is not saved.

        Args:
            save_path: Path of the file to write
        """

    @staticmethod
    def load(load_path: str, venv: VecEnv) -> "VecNormalize":
        """
        Restore a wrapper written by ``save`` and attach it to ``venv``.

        ``venv`` must be the raw, un-normalized vectorized environment.
        Passing an env that is already wrapped in VecNormalize would
        normalize twice. The file is read with pickle, so only load files
        you trust.

        Args:
            load_path: Path of the file written by ``save``
            venv: Vectorized environment to wrap

        Returns:
            The restored VecNormalize wrapper
        """

class VecFrameStack(VecEnvWrapper):
    """
    Stack the most recent observations of each environment into one observation.

    Frame stacking gives a feed-forward (non-recurrent) policy access to
    short-term temporal information such as motion or velocity.

    Args:
        venv: Vectorized environment to wrap
        n_stack: Number of consecutive frames to stack
        channels_order: Axis along which frames are concatenated. Use 'last'
            for channel-last (H, W, C) images and 'first' for channel-first
            (C, H, W) images. It must match the layout of the wrapped
            environment's observations.
    """
    def __init__(
        self,
        venv: VecEnv,
        n_stack: int,
        channels_order: str = "last",
    ): ...

    def reset(self) -> VecEnvObs:
        """Reset and initialize frame stack."""

    def step_wait(self) -> VecEnvStepReturn:
        """Step and update frame stack."""

class VecTransposeImage(VecEnvWrapper):
    """
    Move the channel axis of image observations to the front.

    Observations shaped (H, W, C) come out shaped (C, H, W).

    Args:
        venv: Vectorized environment being wrapped
        skip: When True, no transposition is applied (for debugging)
    """

    def __init__(self, venv: VecEnv, skip: bool = False):
        ...

class VecMonitor(VecEnvWrapper):
    """
    Monitor wrapper for vectorized environments (episode reward/length logging).

    Args:
        venv: Vectorized environment to wrap
        filename: Path of the CSV log, or None for no file logging.
            ".monitor.csv" is appended unless the path already ends with
            "monitor.csv". If the path is an existing directory, a file named
            "monitor.csv" is created inside it. Read the logs back with
            ``load_results(<directory>)``.
        info_keywords: Info dict keys to log
    """
    def __init__(
        self,
        venv: VecEnv,
        filename: Optional[str] = None,
        info_keywords: Tuple[str, ...] = (),
    ): ...

class VecCheckNan(VecEnvWrapper):
    """
    Detect NaN (and optionally infinite) values in actions, observations, and rewards.

    This helps locate where invalid numbers first enter the training loop.

    Args:
        venv: Vectorized environment to wrap
        raise_exception: Raise a ValueError on detection instead of issuing a warning
        warn_once: Only emit the first warning and stay silent afterwards
        check_inf: Also treat +inf / -inf as invalid values
    """
    def __init__(
        self,
        venv: VecEnv,
        raise_exception: bool = False,
        warn_once: bool = True,
        check_inf: bool = True,
    ): ...

class VecExtractDictObs(VecEnvWrapper):
    """
    Replace dictionary observations with the value stored under a single key.

    Args:
        venv: Vectorized environment being wrapped
        key: Key whose value becomes the observation
    """

    def __init__(self, venv: VecEnv, key: str):
        ...

class VecVideoRecorder(VecEnvWrapper):
    """
    Save videos of a vectorized environment to disk.

    Args:
        venv: Vectorized environment being wrapped
        video_folder: Folder that receives the video files
        record_video_trigger: Callable taking an int and returning whether to record
        video_length: Length of each recorded video
        name_prefix: Prefix used for the video file names
    """

    def __init__(
        self,
        venv: VecEnv,
        video_folder: str,
        record_video_trigger: Callable[[int], bool],
        video_length: int = 200,
        name_prefix: str = "rl-video",
    ):
        ...

Environment Utilities

Additional utilities for environment management, monitoring, and validation that complement the vectorized environment system.

class Monitor(gym.Wrapper):
    """
    Environment wrapper that records per-episode statistics (reward, length, time).

    Args:
        env: Environment to wrap
        filename: Path of the CSV log, or None for no file logging.
            ".monitor.csv" is appended unless the path already ends with
            "monitor.csv". If the path is an existing directory, a file named
            "monitor.csv" is created inside it. Read the logs back with
            ``load_results(<directory>)``.
        allow_early_resets: If False, calling reset() before the current
            episode is over raises an error
        reset_keywords: Names of keyword arguments that must be passed to
            reset(); their values are written to the log
        info_keywords: Keys of the step() info dict whose values are written
            to the log when an episode ends
        override_existing: Overwrite an existing log file if True, append to it if False
    """
    def __init__(
        self,
        env: gym.Env,
        filename: Optional[str] = None,
        allow_early_resets: bool = True,
        reset_keywords: Tuple[str, ...] = (),
        info_keywords: Tuple[str, ...] = (),
        override_existing: bool = True,
    ): ...

    def reset(self, **kwargs) -> Tuple[np.ndarray, Dict[str, Any]]:
        """
        Reset the wrapped environment and start tracking a new episode.

        Returns:
            The (observation, info) pair from the wrapped environment
        """

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]:
        """
        Step the wrapped environment and accumulate episode statistics.

        When an episode ends, its statistics are logged and also exposed as
        ``info["episode"]``.

        Returns:
            The (observation, reward, terminated, truncated, info) tuple
        """

def make_vec_env(
    env_id: Union[str, Callable[[], gym.Env]],
    n_envs: int = 1,
    seed: Optional[int] = None,
    start_index: int = 0,
    monitor_dir: Optional[str] = None,
    wrapper_class: Optional[Callable[[gym.Env], gym.Env]] = None,
    env_kwargs: Optional[Dict[str, Any]] = None,
    vec_env_cls: Type[VecEnv] = DummyVecEnv,
    vec_env_kwargs: Optional[Dict[str, Any]] = None,
    monitor_kwargs: Optional[Dict[str, Any]] = None,
    wrapper_kwargs: Optional[Dict[str, Any]] = None,
) -> VecEnv:
    """
    Build a vectorized environment, optionally adding Monitor logging and a wrapper.

    Args:
        env_id: Registered environment ID, or a callable that creates the environment
        n_envs: How many environment copies to create
        seed: Seed for the environments' random number generators
        start_index: Index from which environment seeding starts
        monitor_dir: Folder receiving the Monitor log files
        wrapper_class: Wrapper applied to each environment
        env_kwargs: Keyword arguments used when creating each environment
        vec_env_cls: VecEnv class that holds the environments
        vec_env_kwargs: Keyword arguments for ``vec_env_cls``
        monitor_kwargs: Keyword arguments for the Monitor wrapper
        wrapper_kwargs: Keyword arguments for ``wrapper_class``

    Returns:
        The resulting vectorized environment
    """
    ...

def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -> None:
    """
    Verify that an environment follows the Gym interface.

    Args:
        env: Environment under test
        warn: Emit warnings for detected issues
        skip_render_check: Do not check the render method
    """
    ...

Usage Examples

Basic Vectorized Environment Setup

import gymnasium as gym

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv

# One zero-argument callable per environment instance
env_fns = [lambda: gym.make("CartPole-v1") for _ in range(4)]

# SubprocVecEnv starts its workers with 'forkserver' or 'spawn' by default.
# Both re-import this module in every worker, so environment creation must be
# guarded by `if __name__ == "__main__":`.
if __name__ == "__main__":
    # Sequential vectorization (single process)
    vec_env = DummyVecEnv(env_fns)
    vec_env.close()  # release it before switching to the parallel version

    # Parallel vectorization (multiprocessing)
    vec_env = SubprocVecEnv(env_fns)

    # Use with algorithm
    model = PPO("MlpPolicy", vec_env, verbose=1)

Environment Normalization

import gymnasium as gym

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

# Create and wrap environment
vec_env = DummyVecEnv([lambda: gym.make("Pendulum-v1") for _ in range(4)])
vec_env = VecNormalize(
    vec_env,
    norm_obs=True,
    norm_reward=True,
    clip_obs=10.0,
    clip_reward=10.0,
)

# Train with normalization
model = PPO("MlpPolicy", vec_env, verbose=1)
model.learn(total_timesteps=10000)

# Save normalization statistics
vec_env.save("vecnormalize.pkl")

# Load for evaluation. Wrap a *fresh, un-normalized* env: passing `vec_env`
# (already a VecNormalize) would stack a second normalization layer.
eval_env = DummyVecEnv([lambda: gym.make("Pendulum-v1")])
eval_env = VecNormalize.load("vecnormalize.pkl", eval_env)
eval_env.training = False  # Do not update the running statistics during evaluation
eval_env.norm_reward = False  # Report the true, unnormalized rewards

Frame Stacking for Atari

import gymnasium as gym

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecTransposeImage

# Create Atari environments (requires the Atari dependencies, e.g. ale-py).
# NOTE: stable_baselines3.common.env_util.make_atari_env also applies the
# standard Atari preprocessing and is usually preferable to a bare gym.make.
env_fns = [lambda: gym.make("BreakoutNoFrameskip-v4") for _ in range(4)]
vec_env = DummyVecEnv(env_fns)

# Stack 4 frames first, while observations are still channel-last (H, W, C),
# so that frames are concatenated along the channel axis
vec_env = VecFrameStack(vec_env, n_stack=4, channels_order="last")

# Then transpose the stacked images for the CNN: (H, W, C) -> (C, H, W)
vec_env = VecTransposeImage(vec_env)

model = PPO("CnnPolicy", vec_env, verbose=1)

Environment Monitoring

import os

import gymnasium as gym
from stable_baselines3.common.monitor import Monitor, load_results
from stable_baselines3.common.vec_env import DummyVecEnv, VecMonitor

# Monitor/VecMonitor do not create missing folders, so create the log directory first
log_dir = "logs/"
os.makedirs(log_dir, exist_ok=True)

# Single environment monitoring.
# ".monitor.csv" is appended automatically -> logs/training.monitor.csv
env = Monitor(gym.make("CartPole-v1"), os.path.join(log_dir, "training"))

# Vectorized environment monitoring -> logs/vec_training.monitor.csv
vec_env = DummyVecEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
vec_env = VecMonitor(vec_env, os.path.join(log_dir, "vec_training"))

# ... train or run episodes here; only completed episodes are logged ...

# load_results() takes a *directory* and reads every "*monitor.csv" file in it
results = load_results(log_dir)
print(f"Mean reward: {results['r'].mean():.2f}")

Custom Environment Creation

import gymnasium as gym

# make_vec_env lives in env_util, not in the vec_env package
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv


# Custom environment function
def make_custom_env():
    env = gym.make("CartPole-v1")
    # Add custom preprocessing here
    return env


# SubprocVecEnv workers re-import this module, so build the environments
# only when the script is run directly
if __name__ == "__main__":
    # Create multiple environments with monitoring
    vec_env = make_vec_env(
        "CartPole-v1",
        n_envs=4,
        seed=42,
        monitor_dir="logs/",
        vec_env_cls=SubprocVecEnv,
    )
    vec_env.close()

    vec_env = make_vec_env(
        make_custom_env,
        n_envs=4,
        vec_env_cls=DummyVecEnv,
    )

Utility Functions

Environment utility functions for wrapper management and synchronization:

def unwrap_vec_wrapper(
    env: VecEnv,
    vec_wrapper_class: type[VecEnvWrapper],
) -> Optional[VecEnvWrapper]:
    """
    Search recursively through the wrappers of ``env`` for a given wrapper type.

    Args:
        env: VecEnv whose wrapper chain is searched
        vec_wrapper_class: VecEnvWrapper subclass to look for

    Returns:
        The matching wrapper object, or None when there is none
    """
    ...

def unwrap_vec_normalize(env: VecEnv) -> Optional[VecNormalize]:
    """
    Search recursively through the wrappers of ``env`` for a VecNormalize.

    Args:
        env: VecEnv whose wrapper chain is searched

    Returns:
        The VecNormalize wrapper, or None when there is none
    """
    ...

def is_vecenv_wrapped(
    env: VecEnv,
    vec_wrapper_class: type[VecEnvWrapper],
) -> bool:
    """
    Tell whether ``env`` already carries a wrapper of the given type.

    Args:
        env: VecEnv to inspect
        vec_wrapper_class: VecEnvWrapper subclass to look for

    Returns:
        True if such a wrapper is present, False otherwise
    """
    ...

def sync_envs_normalization(env: VecEnv, eval_env: VecEnv) -> None:
    """
    Copy normalization statistics so that the evaluation env matches training.

    Applies when both environments are wrapped in VecNormalize.

    Args:
        env: Environment used for training
        eval_env: Environment used for evaluation
    """
    ...

Types

from typing import Union, Optional, Type, Callable, Dict, Any, List, Tuple, Sequence
import numpy as np
import gymnasium as gym
from stable_baselines3.common.vec_env import VecEnv, VecEnvWrapper, DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.type_aliases import GymEnv

Install with Tessl CLI

npx tessl i tessl/pypi-stable-baselines3

docs

algorithms.md

common-framework.md

environments.md

her.md

index.md

training-utilities.md

tile.json