PyTorch version of Stable Baselines: implementations of reinforcement learning algorithms.
Callbacks, noise generators, evaluation tools, and other utilities to enhance and monitor training processes. These components provide essential functionality for experiment management, hyperparameter tuning, and production deployment of RL systems.
Event-driven system for monitoring, evaluating, and controlling training processes with customizable hooks at various training stages.
class BaseCallback:
    """
    Abstract base class for training callbacks.

    Args:
        verbose: Verbosity level (0: quiet, 1: info, 2: debug)
    """

    def __init__(self, verbose: int = 0): ...

    def init_callback(self, model: "BaseAlgorithm") -> None:
        """Initialize callback with algorithm instance."""

    def on_training_start(
        self, locals_: Dict[str, Any], globals_: Dict[str, Any]
    ) -> None:
        """Called when training begins."""

    def on_rollout_start(self) -> None:
        """Called before collecting rollouts."""

    def on_step(self) -> bool:
        """
        Called after each environment step.

        Returns:
            True to continue training, False to stop
        """

    def on_rollout_end(self) -> None:
        """Called after rollout collection."""

    def on_training_end(self) -> None:
        """Called when training ends."""

    def update_locals(self, locals_: Dict[str, Any]) -> None:
        """Update callback with current local variables."""
class EventCallback(BaseCallback):
    """
    Base class for event-triggered callbacks.

    Args:
        callback: Child callback to trigger
        verbose: Verbosity level
    """

    def __init__(self, callback: Optional["BaseCallback"] = None, verbose: int = 0): ...

    def _trigger_event(self) -> bool:
        """Trigger child callback if conditions are met."""

    def _on_event(self) -> bool:
        """Event handler (to be implemented by subclasses)."""
class CallbackList(BaseCallback):
    """
    Container for multiple callbacks, dispatched in list order.

    Args:
        callbacks: List of callback instances
    """

    def __init__(self, callbacks: List[BaseCallback]): ...

    def on_training_start(
        self, locals_: Dict[str, Any], globals_: Dict[str, Any]
    ) -> None:
        """Call on_training_start for all callbacks."""

    def on_step(self) -> bool:
        """Call on_step for all callbacks, stop if any returns False."""
class EvalCallback(EventCallback):
    """
    Evaluate agent during training and save the best model.

    Args:
        eval_env: Environment for evaluation
        callback_on_new_best: Callback triggered when new best model found
        callback_after_eval: Callback triggered after evaluation
        n_eval_episodes: Number of episodes for evaluation
        eval_freq: Evaluation frequency (steps)
        log_path: Path for evaluation logs
        best_model_save_path: Path to save best model
        deterministic: Use deterministic actions during evaluation
        render: Render evaluation episodes
        verbose: Verbosity level
        warn: Show warnings for evaluation issues
    """

    def __init__(
        self,
        eval_env: Union[gym.Env, VecEnv],
        callback_on_new_best: Optional[BaseCallback] = None,
        callback_after_eval: Optional[BaseCallback] = None,
        n_eval_episodes: int = 5,
        eval_freq: int = 10000,
        log_path: Optional[str] = None,
        best_model_save_path: Optional[str] = None,
        deterministic: bool = True,
        render: bool = False,
        verbose: int = 1,
        warn: bool = True,
    ): ...

    def _on_step(self) -> bool:
        """Evaluate model if eval_freq steps have passed."""

    def _on_event(self) -> bool:
        """Perform evaluation and save best model."""
class CheckpointCallback(BaseCallback):
    """
    Save the model at regular intervals.

    Args:
        save_freq: Frequency for saving checkpoints (steps)
        save_path: Directory to save checkpoints
        name_prefix: Prefix for checkpoint filenames
        save_replay_buffer: Whether to save replay buffer
        save_vecnormalize: Whether to save VecNormalize statistics
        verbose: Verbosity level
    """

    def __init__(
        self,
        save_freq: int,
        save_path: str,
        name_prefix: str = "rl_model",
        save_replay_buffer: bool = False,
        save_vecnormalize: bool = False,
        verbose: int = 0,
    ): ...

    def _on_step(self) -> bool:
        """Save checkpoint if save_freq steps have passed."""
class StopTrainingOnRewardThreshold(BaseCallback):
    """
    Stop training when a reward threshold is reached.

    NOTE(review): in stable-baselines3 this callback reads its parent
    EvalCallback's best mean reward, so it must be passed via
    ``callback_on_new_best`` — confirm against the installed version.

    Args:
        reward_threshold: Minimum average reward to stop training
        verbose: Verbosity level
    """

    def __init__(self, reward_threshold: float, verbose: int = 0): ...

    def _on_step(self) -> bool:
        """Check if reward threshold is reached."""
class StopTrainingOnMaxEpisodes(BaseCallback):
    """
    Stop training after a maximum number of episodes.

    Args:
        max_episodes: Maximum number of episodes
        verbose: Verbosity level
    """

    def __init__(self, max_episodes: int, verbose: int = 0): ...

    def _on_step(self) -> bool:
        """Check if maximum episodes reached."""
class ProgressBarCallback(BaseCallback):
    """
    Display a training progress bar using tqdm.

    Args:
        refresh_freq: Progress bar refresh frequency
    """

    def __init__(self, refresh_freq: int = 1): ...

    def on_training_start(
        self, locals_: Dict[str, Any], globals_: Dict[str, Any]
    ) -> None:
        """Initialize progress bar."""

    def _on_step(self) -> bool:
        """Update progress bar."""
class EveryNTimesteps(EventCallback):
    """
    Trigger a child callback every n timesteps.

    Args:
        n_steps: Number of timesteps between triggers
        callback: Callback to trigger
    """

    def __init__(self, n_steps: int, callback: BaseCallback): ...
class ConvertCallback(BaseCallback):
    """
    Convert a functional (old-style) callback to a callback object.

    Args:
        callback: Optional callback function
        verbose: Verbosity level
    """

    def __init__(self, callback: Optional[Callable], verbose: int = 0): ...
class StopTrainingOnNoModelImprovement(BaseCallback):
"""
Stop training if no new best model after N consecutive evaluations.
Must be used with EvalCallback.
Args:
max_no_improvement_evals: Max consecutive evaluations without improvement
min_evals: Number of evaluations before counting
verbose: Verbosity level
"""
def __init__(self, max_no_improvement_evals: int, min_evals: int = 0, verbose: int = 0): ...Action noise generators for exploration in continuous control environments, providing various stochastic processes for effective exploration strategies.
class ActionNoise:
    """Abstract base class for action noise processes."""

    def __call__(self) -> np.ndarray:
        """Generate a noise sample."""

    def reset(self) -> None:
        """Reset internal noise state."""
class NormalActionNoise(ActionNoise):
    """
    Gaussian action noise for exploration.

    Args:
        mean: Mean of the noise distribution
        sigma: Standard deviation of the noise distribution
    """

    def __init__(self, mean: np.ndarray, sigma: np.ndarray): ...

    def __call__(self) -> np.ndarray:
        """Sample from the Gaussian distribution."""

    def reset(self) -> None:
        """Reset noise (no-op for memoryless noise)."""
class OrnsteinUhlenbeckActionNoise(ActionNoise):
    """
    Ornstein-Uhlenbeck process noise for temporally correlated exploration.

    Args:
        mean: Long-run mean of the process
        sigma: Volatility parameter
        theta: Rate of mean reversion
        dt: Time step
        initial_noise: Initial noise value
    """

    def __init__(
        self,
        mean: np.ndarray,
        sigma: np.ndarray,
        theta: float = 0.15,
        dt: float = 1e-2,
        initial_noise: Optional[np.ndarray] = None,
    ): ...

    def __call__(self) -> np.ndarray:
        """Sample next noise value from the OU process."""

    def reset(self) -> None:
        """Reset process to initial state."""
class VectorizedActionNoise(ActionNoise):
"""
Vectorized noise for multiple environments.
Args:
noise_fn: Function to create noise instances
n_envs: Number of environments
"""
def __init__(self, noise_fn: Callable[[], ActionNoise], n_envs: int): ...
def __call__(self) -> np.ndarray:
"""Sample noise for all environments."""
def reset(self) -> None:
"""Reset all noise instances."""Comprehensive evaluation utilities for assessing trained agents, including statistical analysis and performance monitoring across multiple episodes.
def evaluate_policy(
model: "BaseAlgorithm",
env: Union[gym.Env, VecEnv],
n_eval_episodes: int = 10,
deterministic: bool = True,
render: bool = False,
callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None,
reward_threshold: Optional[float] = None,
return_episode_rewards: bool = False,
warn: bool = True,
verbose: int = 1,
) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:
"""
Evaluate trained agent on environment.
Args:
model: Trained RL algorithm
env: Environment for evaluation
n_eval_episodes: Number of episodes to evaluate
deterministic: Use deterministic actions
render: Render environment during evaluation
callback: Custom callback function
reward_threshold: Minimum reward threshold for success
return_episode_rewards: Return individual episode rewards
warn: Show warnings
verbose: Verbosity level
Returns:
Mean reward and standard deviation, or episode rewards and lengths
"""General-purpose utilities for reproducibility, tensor operations, mathematical computations, and other common tasks in RL training.
def set_random_seed(seed: int, using_cuda: bool = False) -> None:
    """
    Set random seed for reproducibility.

    Args:
        seed: Random seed value
        using_cuda: Whether CUDA is being used
    """
def get_system_info() -> Dict[str, Any]:
    """
    Get system information for debugging.

    Returns:
        Dictionary containing system information
    """
def get_device(device: Union[torch.device, str] = "auto") -> torch.device:
    """
    Get a PyTorch device from a string specification.

    Args:
        device: Device specification ("auto", "cpu", "cuda", etc.)

    Returns:
        PyTorch device object
    """
def obs_as_tensor(
    obs: Union[np.ndarray, Dict[str, np.ndarray]], device: torch.device
) -> Union[torch.Tensor, TensorDict]:
    """
    Convert observations to PyTorch tensors.

    Args:
        obs: Observations to convert
        device: Target device for tensors

    Returns:
        Tensor or dictionary of tensors
    """
def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray:
    """
    Calculate explained variance of predictions.

    Args:
        y_pred: Predicted values
        y_true: True values

    Returns:
        Explained variance
    """
def polyak_update(
    params: Iterable[torch.nn.Parameter],
    target_params: Iterable[torch.nn.Parameter],
    tau: float,
) -> None:
    """
    Polyak averaging for target network updates.

    Args:
        params: Source parameters
        target_params: Target parameters to update
        tau: Update coefficient (0 = no update, 1 = hard update)
    """
def update_learning_rate(optimizer: torch.optim.Optimizer, learning_rate: float) -> None:
    """
    Update the learning rate of an optimizer in place.

    Args:
        optimizer: PyTorch optimizer
        learning_rate: New learning rate
    """
def safe_mean(arr: List[float]) -> np.ndarray:
"""
Calculate mean safely, handling empty arrays.
Args:
arr: Array of values
Returns:
Mean value or NaN if array is empty
"""Learning rate and parameter scheduling utilities for adaptive training dynamics and improved convergence behavior.
def get_schedule_fn(value_schedule: Union[float, str, Schedule]) -> Schedule:
    """
    Convert a value to a schedule function.

    Args:
        value_schedule: Constant value, string identifier, or schedule function

    Returns:
        Schedule function
    """
def get_linear_fn(
    start: float, end: float, end_fraction: float
) -> Callable[[float], float]:
    """
    Create a linear schedule function.

    Args:
        start: Initial value
        end: Final value
        end_fraction: Fraction of training when end value is reached

    Returns:
        Linear schedule function
    """
def constant_fn(val: float) -> Callable[[float], float]:
    """
    Create a constant schedule function.

    Args:
        val: Constant value

    Returns:
        Constant schedule function
    """

import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import (
    CallbackList, EvalCallback, CheckpointCallback, StopTrainingOnRewardThreshold
)
from stable_baselines3.common.vec_env import DummyVecEnv

# Create training and evaluation environments
train_env = DummyVecEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
eval_env = DummyVecEnv([lambda: gym.make("CartPole-v1")])

# StopTrainingOnRewardThreshold must run as a child of EvalCallback
# (in SB3 it reads its parent's best mean reward and asserts a parent
# exists), so wire it in via callback_on_new_best instead of placing
# it directly in the CallbackList.
stop_callback = StopTrainingOnRewardThreshold(
    reward_threshold=195.0,  # CartPole-v1 is considered solved at 195
    verbose=1
)
eval_callback = EvalCallback(
    eval_env,
    callback_on_new_best=stop_callback,
    best_model_save_path="./logs/best_model/",
    log_path="./logs/results/",
    eval_freq=10000,
    n_eval_episodes=5,
    deterministic=True,
    render=False
)
checkpoint_callback = CheckpointCallback(
    save_freq=50000,
    save_path="./logs/checkpoints/",
    name_prefix="ppo_cartpole",
    save_replay_buffer=False,
    save_vecnormalize=False
)

# Combine callbacks
callback = CallbackList([eval_callback, checkpoint_callback])

# Train with callbacks
model = PPO("MlpPolicy", train_env, verbose=1)
model.learn(total_timesteps=100000, callback=callback)

import numpy as np
from stable_baselines3 import TD3
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise

# NOTE(review): `env` is assumed to be an already-created continuous-action
# environment — confirm it is defined before this snippet runs.
n_actions = env.action_space.shape[-1]

# Gaussian noise
action_noise = NormalActionNoise(
    mean=np.zeros(n_actions),
    sigma=0.1 * np.ones(n_actions)
)

# Ornstein-Uhlenbeck noise for temporally correlated exploration
# (this rebinds action_noise — use only one of the two in practice)
action_noise = OrnsteinUhlenbeckActionNoise(
    mean=np.zeros(n_actions),
    sigma=0.1 * np.ones(n_actions),
    theta=0.15,
    dt=1e-2
)

model = TD3(
    "MlpPolicy",
    env,
    action_noise=action_noise,
    verbose=1
)
model.learn(total_timesteps=100000)

from stable_baselines3.common.evaluation import evaluate_policy
import matplotlib.pyplot as plt

# Detailed evaluation: collect per-episode statistics instead of just the mean
episode_rewards, episode_lengths = evaluate_policy(
    model,
    eval_env,
    n_eval_episodes=100,
    deterministic=True,
    return_episode_rewards=True
)

# Statistical analysis
print(f"Mean reward: {np.mean(episode_rewards):.2f} ± {np.std(episode_rewards):.2f}")
print(f"Mean episode length: {np.mean(episode_lengths):.2f}")

# Plot reward and length distributions side by side
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.hist(episode_rewards, bins=20)
plt.xlabel("Episode Reward")
plt.ylabel("Frequency")
plt.subplot(1, 2, 2)
plt.hist(episode_lengths, bins=20)
plt.xlabel("Episode Length")
plt.ylabel("Frequency")
plt.show()

from stable_baselines3.common.utils import get_linear_fn
# Linear learning rate decay from 3e-4 down to 1e-5 over the full run
learning_rate = get_linear_fn(3e-4, 1e-5, 1.0)
model = PPO(
    "MlpPolicy",
    env,
    learning_rate=learning_rate,
    verbose=1
)

# Custom schedule function: SB3 passes progress_remaining,
# which goes from 1 (start of training) to 0 (end of training)
def custom_schedule(progress_remaining):
    """Custom learning rate schedule."""
    if progress_remaining > 0.5:
        return 3e-4
    else:
        return 1e-4

model = PPO(
    "MlpPolicy",
    env,
    learning_rate=custom_schedule,
    verbose=1
)

from stable_baselines3.common.callbacks import BaseCallback
class LoggingCallback(BaseCallback):
    """Custom callback for additional logging."""

    def __init__(self, verbose=0):
        # Modern zero-argument super() — equivalent to the legacy
        # super(LoggingCallback, self) form.
        super().__init__(verbose)
        self.episode_rewards = []

    def _on_step(self) -> bool:
        # Access training variables maintained by the algorithm
        if len(self.model.ep_info_buffer) > 0:
            mean_reward = np.mean([ep['r'] for ep in self.model.ep_info_buffer])
            self.logger.record("custom/mean_episode_reward", mean_reward)
        return True  # Continue training

# Use custom callback
custom_callback = LoggingCallback(verbose=1)
model.learn(total_timesteps=100000, callback=custom_callback)

from typing import Union, Optional, Type, Callable, Dict, Any, List, Tuple
import numpy as np
import gymnasium as gym
from stable_baselines3.common.callbacks import BaseCallback, EventCallback, CallbackList
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
from stable_baselines3.common.noise import ActionNoise, NormalActionNoise, OrnsteinUhlenbeckActionNoise
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.base_class import BaseAlgorithm

Install with Tessl CLI
npx tessl i tessl/pypi-stable-baselines3