CtrlK
Blog | Docs | Log in | Get started
Tessl Logo

tessl/pypi-stable-baselines3

Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.

Overview
Eval results
Files

docs/training-utilities.md

Training Utilities

Callbacks, noise generators, evaluation tools, and other utilities to enhance and monitor training processes. These components provide essential functionality for experiment management, hyperparameter tuning, and production deployment of RL systems.

Capabilities

Callback System

Event-driven system for monitoring, evaluating, and controlling training processes with customizable hooks at various training stages.

class BaseCallback:
    """
    Abstract base class for training callbacks.

    The public ``on_*`` hooks below are the entry points invoked around the
    training loop; the concrete callbacks in this module customise behaviour
    through the matching underscore-prefixed hooks (e.g. ``_on_step``).

    Args:
        verbose: Verbosity level (0: quiet, 1: info, 2: debug)
    """
    def __init__(self, verbose: int = 0): ...

    def init_callback(self, model: "BaseAlgorithm") -> None:
        """Bind the callback to the algorithm instance it will monitor."""

    def on_training_start(
        self, locals_: Dict[str, Any], globals_: Dict[str, Any]
    ) -> None:
        """
        Called once when training begins.

        Args:
            locals_: Local variables of the calling training function
            globals_: Global variables of the calling training function
        """

    def on_rollout_start(self) -> None:
        """Called before a new batch of rollouts is collected."""

    def on_step(self) -> bool:
        """
        Called after each environment step.

        Returns:
            True to continue training, False to stop
        """

    def on_rollout_end(self) -> None:
        """Called after rollout collection has finished."""

    def on_training_end(self) -> None:
        """Called once when training ends."""

    def update_locals(self, locals_: Dict[str, Any]) -> None:
        """
        Refresh the callback's view of the caller's local variables.

        Args:
            locals_: Current local variables of the training function
        """

class EventCallback(BaseCallback):
    """
    Base class for callbacks that fire a child callback when an event occurs.

    Args:
        callback: Child callback to trigger when the event fires
            (``None`` means the event has no child to run)
        verbose: Verbosity level
    """
    def __init__(self, callback: Optional["BaseCallback"] = None, verbose: int = 0): ...

    def _trigger_event(self) -> bool:
        """
        Trigger child callback if conditions are met.

        NOTE(review): upstream stable-baselines3 does not appear to define
        ``_trigger_event``; the child callback is run from ``_on_event``.
        Confirm against the installed version before relying on this hook.
        """

    def _on_event(self) -> bool:
        """
        Event handler (to be implemented by subclasses).

        Returns:
            Presumably True to continue training, False to stop, matching
            ``BaseCallback.on_step`` -- confirm against the installed version.
        """

class CallbackList(BaseCallback):
    """
    Container that chains several callbacks and runs them in sequence.

    Args:
        callbacks: List of callback instances
    """
    def __init__(self, callbacks: List[BaseCallback]): ...

    def on_training_start(
        self, locals_: Dict[str, Any], globals_: Dict[str, Any]
    ) -> None:
        """Forward ``on_training_start`` to every contained callback."""

    def on_step(self) -> bool:
        """
        Forward ``on_step`` to every contained callback.

        Returns:
            False if any contained callback returned False (training stops),
            True otherwise
        """

class EvalCallback(EventCallback):
    """
    Periodically evaluate the agent during training and save the best model.

    Args:
        eval_env: Environment for evaluation
        callback_on_new_best: Callback triggered when a new best model is found
        callback_after_eval: Callback triggered after every evaluation
        n_eval_episodes: Number of episodes per evaluation
        eval_freq: Evaluation frequency (steps).
            NOTE(review): upstream counts this in callback calls, so with a
            vectorized training env each call covers ``n_envs`` timesteps;
            use ``max(eval_freq // n_envs, 1)`` to express it in timesteps.
        log_path: Path for evaluation logs (``None`` disables logging to file)
        best_model_save_path: Path to save the best model
            (``None`` disables saving)
        deterministic: Use deterministic actions during evaluation
        render: Render evaluation episodes
        verbose: Verbosity level
        warn: Show warnings for evaluation issues
    """
    def __init__(
        self,
        eval_env: Union[gym.Env, VecEnv],
        callback_on_new_best: Optional[BaseCallback] = None,
        callback_after_eval: Optional[BaseCallback] = None,
        n_eval_episodes: int = 5,
        eval_freq: int = 10000,
        log_path: Optional[str] = None,
        best_model_save_path: Optional[str] = None,
        deterministic: bool = True,
        render: bool = False,
        verbose: int = 1,
        warn: bool = True,
    ): ...

    def _on_step(self) -> bool:
        """Run an evaluation if ``eval_freq`` steps have passed."""

    def _on_event(self) -> bool:
        """Perform evaluation and save best model."""

class CheckpointCallback(BaseCallback):
    """
    Save the model at regular intervals.

    Args:
        save_freq: Frequency for saving checkpoints (steps).
            NOTE(review): upstream counts this in callback calls, so with a
            vectorized env each call covers ``n_envs`` timesteps; use
            ``max(save_freq // n_envs, 1)`` to express it in timesteps.
        save_path: Directory to save checkpoints
        name_prefix: Prefix for checkpoint filenames
        save_replay_buffer: Whether to also save the replay buffer
        save_vecnormalize: Whether to also save VecNormalize statistics
        verbose: Verbosity level
    """
    def __init__(
        self,
        save_freq: int,
        save_path: str,
        name_prefix: str = "rl_model",
        save_replay_buffer: bool = False,
        save_vecnormalize: bool = False,
        verbose: int = 0,
    ): ...

    def _on_step(self) -> bool:
        """Save a checkpoint if ``save_freq`` steps have passed."""

class StopTrainingOnRewardThreshold(BaseCallback):
    """
    Stop training once the evaluated reward reaches a threshold.

    NOTE(review): upstream requires this callback to be attached to an
    ``EvalCallback`` (typically via ``callback_on_new_best``), because it
    reads the evaluation result from its parent callback; used on its own it
    fails at runtime. Confirm against the installed version.

    Args:
        reward_threshold: Minimum average reward to stop training
        verbose: Verbosity level
    """
    def __init__(self, reward_threshold: float, verbose: int = 0): ...

    def _on_step(self) -> bool:
        """Check whether the reward threshold has been reached."""

class StopTrainingOnMaxEpisodes(BaseCallback):
    """
    Stop training after a maximum number of episodes.

    Args:
        max_episodes: Maximum number of episodes
            (presumably per environment when training on a vectorized env
            -- confirm against the installed version)
        verbose: Verbosity level
    """
    def __init__(self, max_episodes: int, verbose: int = 0): ...

    def _on_step(self) -> bool:
        """Check whether the maximum number of episodes has been reached."""

class ProgressBarCallback(BaseCallback):
    """
    Display a training progress bar using tqdm.

    Args:
        refresh_freq: Progress bar refresh frequency.
            NOTE(review): the upstream constructor appears to take no
            arguments; confirm this parameter exists in the installed
            version before passing it.
    """
    def __init__(self, refresh_freq: int = 1): ...

    def on_training_start(
        self, locals_: Dict[str, Any], globals_: Dict[str, Any]
    ) -> None:
        """Create the progress bar when training begins."""

    def _on_step(self) -> bool:
        """Advance the progress bar."""

class EveryNTimesteps(EventCallback):
    """
    Trigger a child callback every ``n_steps`` timesteps.

    Args:
        n_steps: Number of timesteps between triggers
        callback: Callback to trigger
    """
    def __init__(self, n_steps: int, callback: BaseCallback): ...

class ConvertCallback(BaseCallback):
    """
    Wrap an old-style functional callback in a callback object.

    Args:
        callback: Callback function to wrap, or ``None``
            (presumably called with the training locals and globals
            -- confirm against the installed version)
        verbose: Verbosity level
    """
    def __init__(self, callback: Optional[Callable], verbose: int = 0): ...

class StopTrainingOnNoModelImprovement(BaseCallback):
    """
    Stop training if there is no new best model after N consecutive evaluations.

    Must be used with EvalCallback, which supplies the evaluation results.

    Args:
        max_no_improvement_evals: Max consecutive evaluations without improvement
        min_evals: Number of evaluations to run before starting to count
            evaluations without improvement
        verbose: Verbosity level
    """
    def __init__(self, max_no_improvement_evals: int, min_evals: int = 0, verbose: int = 0): ...

Noise Classes

Action noise generators for exploration in continuous control environments, providing various stochastic processes for effective exploration strategies.

class ActionNoise:
    """Abstract base class for action noise generators."""

    def __call__(self) -> np.ndarray:
        """
        Generate one noise sample.

        Returns:
            Noise array to be added to an action
        """

    def reset(self) -> None:
        """Reset any internal noise state (e.g. at the end of an episode)."""

class NormalActionNoise(ActionNoise):
    """
    Gaussian (memoryless) action noise for exploration.

    Args:
        mean: Mean of the noise distribution
        sigma: Standard deviation of the noise distribution
    """
    def __init__(self, mean: np.ndarray, sigma: np.ndarray): ...

    def __call__(self) -> np.ndarray:
        """Draw one sample from the Gaussian distribution."""

    def reset(self) -> None:
        """Reset noise (no-op for memoryless noise)."""

class OrnsteinUhlenbeckActionNoise(ActionNoise):
    """
    Ornstein-Uhlenbeck process noise for temporally correlated exploration.

    Unlike Gaussian noise, successive samples depend on the previous one, so
    the process carries state between calls.

    Args:
        mean: Long-run mean the process reverts to
        sigma: Volatility (scale of the random perturbation)
        theta: Rate of mean reversion
        dt: Time step
        initial_noise: Initial noise value
            (``None`` presumably starts the process at zero -- confirm)
    """
    def __init__(
        self,
        mean: np.ndarray,
        sigma: np.ndarray,
        theta: float = 0.15,
        dt: float = 1e-2,
        initial_noise: Optional[np.ndarray] = None,
    ): ...

    def __call__(self) -> np.ndarray:
        """Advance the OU process by one step and return the new noise value."""

    def reset(self) -> None:
        """Reset the process to its initial state."""

class VectorizedActionNoise(ActionNoise):
    """
    Action noise for several parallel environments.

    Args:
        noise_fn: Function to create noise instances.
            NOTE(review): the upstream constructor appears to take an
            ``ActionNoise`` instance named ``base_noise`` (copied once per
            environment) rather than a factory; confirm against the
            installed version.
        n_envs: Number of environments
    """
    def __init__(self, noise_fn: Callable[[], ActionNoise], n_envs: int): ...

    def __call__(self) -> np.ndarray:
        """Sample noise for all environments."""

    def reset(self) -> None:
        """Reset the noise instance of every environment."""

Evaluation Functions

Comprehensive evaluation utilities for assessing trained agents, including statistical analysis and performance monitoring across multiple episodes.

def evaluate_policy(
    model: "BaseAlgorithm",
    env: Union[gym.Env, VecEnv],
    n_eval_episodes: int = 10,
    deterministic: bool = True,
    render: bool = False,
    callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None,
    reward_threshold: Optional[float] = None,
    return_episode_rewards: bool = False,
    warn: bool = True,
    verbose: int = 1,
) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:
    """
    Run a trained agent for a number of episodes and report its performance.

    Args:
        model: Trained RL algorithm
        env: Environment for evaluation
        n_eval_episodes: Number of episodes to evaluate
        deterministic: Use deterministic actions
        render: Render environment during evaluation
        callback: Custom callback function, called with the local and
            global variables (presumably after each step -- confirm)
        reward_threshold: Minimum reward threshold for success.
            NOTE(review): upstream raises an error when the mean reward is
            below this value; confirm against the installed version.
        return_episode_rewards: Return per-episode rewards and lengths
            instead of the mean and standard deviation
        warn: Show warnings
        verbose: Verbosity level.
            NOTE(review): upstream ``evaluate_policy`` does not appear to
            accept ``verbose``; confirm before passing it.

    Returns:
        ``(mean_reward, std_reward)`` by default, or
        ``(episode_rewards, episode_lengths)`` when
        ``return_episode_rewards`` is True
    """

Utility Functions

General-purpose utilities for reproducibility, tensor operations, mathematical computations, and other common tasks in RL training.

def set_random_seed(seed: int, using_cuda: bool = False) -> None:
    """
    Seed the random number generators for reproducibility.

    Args:
        seed: Random seed value
        using_cuda: Whether CUDA is being used (presumably enables
            additional CUDA-specific determinism settings -- confirm)
    """

def get_system_info() -> Dict[str, Any]:
    """
    Collect system information for debugging and bug reports.

    NOTE(review): upstream appears to accept a ``print_info`` flag and to
    return a ``(dict, formatted_string)`` tuple rather than a bare dict;
    confirm against the installed version.

    Returns:
        Dictionary containing system information
    """

def get_device(device: Union[torch.device, str] = "auto") -> torch.device:
    """
    Resolve a device specification to a PyTorch device.

    Args:
        device: Device specification ("auto", "cpu", "cuda", etc.);
            "auto" presumably selects the GPU when one is available and
            falls back to the CPU otherwise -- confirm

    Returns:
        PyTorch device object
    """

def obs_as_tensor(
    obs: Union[np.ndarray, Dict[str, np.ndarray]], device: torch.device
) -> Union[torch.Tensor, TensorDict]:
    """
    Convert observations to PyTorch tensors on the given device.

    Args:
        obs: Observations to convert, either a single array or a
            dictionary of arrays keyed by observation name
        device: Target device for tensors

    Returns:
        A tensor for array input, or a dictionary of tensors with the same
        keys for dictionary input
    """

def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray:
    """
    Calculate the fraction of the variance of ``y_true`` explained by ``y_pred``.

    Note the argument order: predictions first, then targets.

    Args:
        y_pred: Predicted values
        y_true: True values

    Returns:
        Explained variance (presumably ``1 - Var[y_true - y_pred] / Var[y_true]``,
        so 1 means a perfect prediction -- confirm against the installed version)
    """

def polyak_update(
    params: Iterable[torch.nn.Parameter],
    target_params: Iterable[torch.nn.Parameter],
    tau: float,
) -> None:
    """
    Polyak (soft) averaging for target network updates.

    Moves ``target_params`` towards ``params`` in place; nothing is returned.

    Args:
        params: Source parameters
        target_params: Target parameters to update
        tau: Update coefficient (0 = no update, 1 = hard update);
            consistent with ``target <- (1 - tau) * target + tau * param``
    """

def update_learning_rate(optimizer: torch.optim.Optimizer, learning_rate: float) -> None:
    """
    Set the learning rate of an optimizer.

    Args:
        optimizer: PyTorch optimizer to modify
        learning_rate: New learning rate
            (presumably applied to every parameter group -- confirm)
    """

def safe_mean(arr: List[float]) -> np.ndarray:
    """
    Calculate the mean, returning NaN instead of failing on empty input.

    Args:
        arr: Array of values

    Returns:
        Mean value or NaN if array is empty.
        NOTE(review): the result is presumably a scalar float despite the
        ``np.ndarray`` annotation -- confirm against the installed version.
    """

Schedule Functions

Learning rate and parameter scheduling utilities for adaptive training dynamics and improved convergence behavior.

def get_schedule_fn(value_schedule: Union[float, str, Schedule]) -> Schedule:
    """
    Normalise a constant or a callable into a schedule function.

    Args:
        value_schedule: Constant value, string identifier, or schedule function.
            NOTE(review): upstream appears to accept only a float or a
            callable schedule; confirm string identifiers are supported
            before relying on them.

    Returns:
        Schedule function (presumably mapping the remaining training
        progress, from 1 down to 0, to the current value -- confirm)
    """

def get_linear_fn(
    start: float, end: float, end_fraction: float
) -> Callable[[float], float]:
    """
    Create a linear schedule that interpolates from ``start`` to ``end``.

    Args:
        start: Initial value
        end: Final value
        end_fraction: Fraction of training when end value is reached
            (e.g. 0.1 means ``end`` is reached after 10% of training and
            presumably held afterwards -- confirm)

    Returns:
        Linear schedule function
    """

def constant_fn(val: float) -> Callable[[float], float]:
    """
    Create a schedule function that always returns the same value.

    Args:
        val: Constant value

    Returns:
        Constant schedule function (its argument is ignored)
    """

Usage Examples

Comprehensive Training Setup

import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import (
    CallbackList,
    CheckpointCallback,
    EvalCallback,
    StopTrainingOnRewardThreshold,
)
from stable_baselines3.common.vec_env import DummyVecEnv

# Number of parallel training environments. Callback frequencies below are
# divided by this value because callbacks are invoked once per vectorized
# step, i.e. once every n_envs timesteps.
n_envs = 4

# Create training and evaluation environments
train_env = DummyVecEnv([lambda: gym.make("CartPole-v1") for _ in range(n_envs)])
eval_env = DummyVecEnv([lambda: gym.make("CartPole-v1")])

# Stop as soon as an evaluation reaches the target reward. This callback
# reads the evaluation result from its parent, so it must be attached to
# EvalCallback (below) rather than added to the CallbackList directly.
stop_callback = StopTrainingOnRewardThreshold(
    reward_threshold=475.0,  # CartPole-v1 is considered solved at 475 (195 is CartPole-v0)
    verbose=1,
)

# Set up callbacks
eval_callback = EvalCallback(
    eval_env,
    callback_on_new_best=stop_callback,
    best_model_save_path="./logs/best_model/",
    log_path="./logs/results/",
    eval_freq=max(10000 // n_envs, 1),  # evaluate every ~10000 timesteps
    n_eval_episodes=5,
    deterministic=True,
    render=False,
)

checkpoint_callback = CheckpointCallback(
    save_freq=max(50000 // n_envs, 1),  # checkpoint every ~50000 timesteps
    save_path="./logs/checkpoints/",
    name_prefix="ppo_cartpole",
    save_replay_buffer=False,
    save_vecnormalize=False,
)

# Combine callbacks (stop_callback is already chained through eval_callback)
callback = CallbackList([eval_callback, checkpoint_callback])

# Train with callbacks
model = PPO("MlpPolicy", train_env, verbose=1)
model.learn(total_timesteps=100000, callback=callback)

Action Noise for Continuous Control

import gymnasium as gym
import numpy as np
from stable_baselines3 import TD3
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise

# TD3 and action noise require a continuous (Box) action space
env = gym.make("Pendulum-v1")
n_actions = env.action_space.shape[-1]

# Option 1: Gaussian noise
action_noise = NormalActionNoise(
    mean=np.zeros(n_actions),
    sigma=0.1 * np.ones(n_actions),
)

# Option 2: Ornstein-Uhlenbeck noise for temporally correlated exploration.
# This assignment replaces the Gaussian noise above; keep only the option you need.
action_noise = OrnsteinUhlenbeckActionNoise(
    mean=np.zeros(n_actions),
    sigma=0.1 * np.ones(n_actions),
    theta=0.15,
    dt=1e-2,
)

model = TD3(
    "MlpPolicy",
    env,
    action_noise=action_noise,
    verbose=1,
)

model.learn(total_timesteps=100000)

Custom Evaluation and Analysis

import matplotlib.pyplot as plt
import numpy as np
from stable_baselines3.common.evaluation import evaluate_policy

# `model` and `eval_env` are the trained agent and the evaluation environment
# created in the "Comprehensive Training Setup" example.

# Detailed evaluation: with return_episode_rewards=True the function returns
# one reward and one length per episode instead of (mean, std).
episode_rewards, episode_lengths = evaluate_policy(
    model,
    eval_env,
    n_eval_episodes=100,
    deterministic=True,
    return_episode_rewards=True,
)

# Statistical analysis
print(f"Mean reward: {np.mean(episode_rewards):.2f} ± {np.std(episode_rewards):.2f}")
print(f"Mean episode length: {np.mean(episode_lengths):.2f}")

# Plot results
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.hist(episode_rewards, bins=20)
plt.xlabel("Episode Reward")
plt.ylabel("Frequency")

plt.subplot(1, 2, 2)
plt.hist(episode_lengths, bins=20)
plt.xlabel("Episode Length")
plt.ylabel("Frequency")
plt.show()

Learning Rate Scheduling

import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.utils import get_linear_fn

env = gym.make("CartPole-v1")

# Linear learning rate decay from 3e-4 to 1e-5 over the whole run
# (end_fraction=1.0 means the final value is reached at the end of training)
learning_rate = get_linear_fn(3e-4, 1e-5, 1.0)

model = PPO(
    "MlpPolicy",
    env,
    learning_rate=learning_rate,
    verbose=1,
)


# Custom schedule function
def custom_schedule(progress_remaining: float) -> float:
    """
    Custom step-wise learning rate schedule.

    Args:
        progress_remaining: Remaining fraction of training, going from
            1.0 (start) down to 0.0 (end)

    Returns:
        3e-4 during the first half of training, 1e-4 afterwards
    """
    return 3e-4 if progress_remaining > 0.5 else 1e-4


model = PPO(
    "MlpPolicy",
    env,
    learning_rate=custom_schedule,
    verbose=1,
)

Custom Callback Creation

import numpy as np
from stable_baselines3.common.callbacks import BaseCallback


class LoggingCallback(BaseCallback):
    """Custom callback that logs the mean reward of recent episodes."""

    def __init__(self, verbose=0):
        super().__init__(verbose)
        self.episode_rewards = []

    def _on_step(self) -> bool:
        """
        Record the mean episode reward after each environment step.

        Returns:
            Always True, so this callback never stops training
        """
        # ep_info_buffer is empty until the first episode has finished;
        # skip logging until there is something to average.
        if len(self.model.ep_info_buffer) > 0:
            mean_reward = np.mean([ep["r"] for ep in self.model.ep_info_buffer])
            self.logger.record("custom/mean_episode_reward", mean_reward)

        return True  # Continue training


# Use custom callback (`model` is an algorithm instance such as the PPO model
# created in the earlier examples)
custom_callback = LoggingCallback(verbose=1)
model.learn(total_timesteps=100000, callback=custom_callback)

Types

from typing import Union, Optional, Type, Callable, Dict, Any, List, Tuple
import numpy as np
import gymnasium as gym
from stable_baselines3.common.callbacks import BaseCallback, EventCallback, CallbackList
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
from stable_baselines3.common.noise import ActionNoise, NormalActionNoise, OrnsteinUhlenbeckActionNoise
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.base_class import BaseAlgorithm

Install with Tessl CLI

npx tessl i tessl/pypi-stable-baselines3

docs

algorithms.md

common-framework.md

environments.md

her.md

index.md

training-utilities.md

tile.json