A Training Framework for Stable Baselines3 Reinforcement Learning Agents

Custom Gymnasium environment wrappers for observation processing, reward
modification, action manipulation, and training optimization. These wrappers
extend environments with specialized functionality needed for effective RL
training.
from typing import Optional

import gymnasium as gym
import numpy as np
from gymnasium.core import ObsType
from gymnasium.wrappers import ResizeObservation

from rl_zoo3.wrappers import (
    ActionNoiseWrapper,
    ActionSmoothingWrapper,
    DelayedRewardWrapper,
    FrameSkip,
    GymResetReturn,  # NOTE(review): type aliases re-exported via rl_zoo3.wrappers — confirm
    GymStepReturn,
    HistoryWrapper,
    HistoryWrapperObsDict,
    MaskVelocityWrapper,
    TruncatedOnSuccessWrapper,
)

# Wrapper that truncates episodes upon achieving success conditions, useful for
# goal-oriented environments and curriculum learning.
class TruncatedOnSuccessWrapper(gym.Wrapper):
    """
    Reset on success and offsets the reward.
    Useful for GoalEnv and goal-oriented tasks.
    """

    def __init__(
        self,
        env: gym.Env,
        reward_offset: float = 0.0,
        n_successes: int = 1,
    ):
        """
        Initialize TruncatedOnSuccessWrapper.

        Parameters:
        - env: Base environment to wrap
        - reward_offset: Offset to add to all rewards
        - n_successes: Number of consecutive successes needed for truncation
        """

    def reset(
        self,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ) -> GymResetReturn:
        """Reset environment and success counter."""

    def step(self, action) -> GymStepReturn:
        """
        Execute action and check for success truncation.

        Returns:
        Tuple of (observation, reward + offset, terminated, truncated, info)
        """

    def compute_reward(self, achieved_goal, desired_goal, info):
        """Compute reward with offset for goal environments."""

# Usage example:
import gymnasium as gym
from rl_zoo3.wrappers import TruncatedOnSuccessWrapper

# Create base environment
env = gym.make("FetchReach-v1")

# Wrap with success truncation
wrapped_env = TruncatedOnSuccessWrapper(
    env,
    reward_offset=1.0,  # Add bonus reward
    n_successes=3,  # Require 3 consecutive successes
)

# Use in training
obs, info = wrapped_env.reset()
for step in range(1000):
    action = wrapped_env.action_space.sample()
    obs, reward, terminated, truncated, info = wrapped_env.step(action)
    if truncated and info.get("is_success", False):
        print(f"Success achieved at step {step}")
        break

# Wrapper that adds configurable noise to agent actions, useful for exploration
# and robustness testing.
class ActionNoiseWrapper(gym.Wrapper[ObsType, np.ndarray, ObsType, np.ndarray]):
    """
    Wrapper that adds noise to actions.
    Useful for exploration and robustness evaluation.
    """

    def __init__(
        self,
        env: gym.Env,
        noise_std: float = 0.1,
        noise_type: str = "gaussian",
    ):
        """
        Initialize ActionNoiseWrapper.

        Parameters:
        - env: Base environment to wrap
        - noise_std: Standard deviation of noise
        - noise_type: Type of noise ('gaussian', 'uniform')
        """

    def step(self, action) -> GymStepReturn:
        """
        Execute action with added noise.

        Parameters:
        - action: Original action from agent

        Returns:
        Environment step result with noisy action applied
        """

# Usage example:
from rl_zoo3.wrappers import ActionNoiseWrapper
import gymnasium as gym
import numpy as np  # needed below for np.array; missing in the original example

# Create environment
env = gym.make("Pendulum-v1")

# Add action noise
noisy_env = ActionNoiseWrapper(
    env,
    noise_std=0.05,  # 5% noise
    noise_type="gaussian",
)

# Actions will have noise added automatically
obs, info = noisy_env.reset()
action = np.array([0.5])  # Clean action
obs, reward, terminated, truncated, info = noisy_env.step(action)  # Noise added internally

# Wrapper that smooths actions over multiple timesteps, reducing jerkiness in
# continuous control tasks.
class ActionSmoothingWrapper(gym.Wrapper):
    """
    Wrapper for action smoothing over multiple timesteps.
    Reduces action jerkiness in continuous control.
    """

    def __init__(
        self,
        env: gym.Env,
        smoothing_coef: float = 0.9,
    ):
        """
        Initialize ActionSmoothingWrapper.

        Parameters:
        - env: Base environment to wrap
        - smoothing_coef: Smoothing coefficient (0.0 = no smoothing, 1.0 = maximum smoothing)
        """

    def step(self, action) -> GymStepReturn:
        """
        Execute smoothed action.

        Parameters:
        - action: Raw action from agent

        Returns:
        Environment step result with smoothed action
        """

    def reset(self, **kwargs) -> GymResetReturn:
        """Reset environment and action history."""

# Wrapper that delays reward delivery by a specified number of steps, useful for
# testing credit assignment and memory.
class DelayedRewardWrapper(gym.Wrapper):
    """
    Wrapper that delays reward delivery.
    Useful for testing credit assignment capabilities.
    """

    def __init__(
        self,
        env: gym.Env,
        delay: int = 10,
    ):
        """
        Initialize DelayedRewardWrapper.

        Parameters:
        - env: Base environment to wrap
        - delay: Number of steps to delay rewards
        """

    def step(self, action) -> GymStepReturn:
        """
        Execute action with delayed reward delivery.

        Returns:
        Step result with current reward set to 0.0, delayed rewards delivered later
        """

    def reset(self, **kwargs) -> GymResetReturn:
        """Reset environment and reward buffer."""

# Usage example:
from rl_zoo3.wrappers import DelayedRewardWrapper
import gymnasium as gym

# Create environment with delayed rewards
env = gym.make("CartPole-v1")
delayed_env = DelayedRewardWrapper(env, delay=5)

# Rewards will be delayed by 5 steps
obs, info = delayed_env.reset()
total_reward = 0
for step in range(100):
    action = delayed_env.action_space.sample()
    obs, reward, terminated, truncated, info = delayed_env.step(action)
    total_reward += reward
    if terminated or truncated:
        print(f"Episode ended with total reward: {total_reward}")
        break

# Wrapper that maintains a history of observations, useful for partially
# observable environments and recurrent policies.
class HistoryWrapper(gym.Wrapper[np.ndarray, np.ndarray, np.ndarray, np.ndarray]):
    """
    Wrapper that maintains observation history.
    Useful for partial observability and recurrent policies.
    """

    def __init__(
        self,
        env: gym.Env,
        horizon: int = 2,
    ):
        """
        Initialize HistoryWrapper.

        Parameters:
        - env: Base environment to wrap (must have Box observation space)
        - horizon: Number of past observations to include
        """

    def reset(self, **kwargs) -> GymResetReturn:
        """Reset environment and observation history."""

    def step(self, action) -> GymStepReturn:
        """
        Execute action and update observation history.

        Returns:
        Step result with concatenated observation history
        """

# Specialized history wrapper for environments with dictionary observation spaces.
class HistoryWrapperObsDict(gym.Wrapper):
    """
    History wrapper for dictionary observation spaces.
    Maintains separate history for each observation key.
    """

    def __init__(
        self,
        env: gym.Env,
        horizon: int = 2,
    ):
        """
        Initialize HistoryWrapperObsDict.

        Parameters:
        - env: Base environment with Dict observation space
        - horizon: Number of past observations to maintain per key
        """

    def reset(self, **kwargs) -> GymResetReturn:
        """Reset environment and all observation histories."""

    def step(self, action) -> GymStepReturn:
        """
        Execute action and update all observation histories.

        Returns:
        Step result with extended dictionary observations
        """

# Wrapper that skips frames and repeats actions, common in Atari and other
# environments for computational efficiency.
class FrameSkip(gym.Wrapper):
    """
    Wrapper for frame skipping (action repeat).
    Repeats actions for multiple frames and returns the final result.
    """

    def __init__(
        self,
        env: gym.Env,
        skip: int = 4,
    ):
        """
        Initialize FrameSkip wrapper.

        Parameters:
        - env: Base environment to wrap
        - skip: Number of frames to skip (action repeat count)
        """

    def step(self, action) -> GymStepReturn:
        """
        Execute action for multiple frames.

        Parameters:
        - action: Action to repeat

        Returns:
        Result after skipping frames with accumulated reward
        """

# Wrapper that masks velocity information from observations, useful for testing
# position-only policies.
class MaskVelocityWrapper(gym.ObservationWrapper):
    """
    Wrapper that masks velocity information from observations.
    Useful for testing position-only policies.
    """

    def __init__(self, env: gym.Env):
        """
        Initialize MaskVelocityWrapper.

        Parameters:
        - env: Base environment (typically MuJoCo-based)
        """

    def observation(self, observation) -> np.ndarray:
        """
        Mask velocity components from observation.

        Parameters:
        - observation: Original observation

        Returns:
        Observation with velocity components set to zero
        """

# Wrapper for resizing observations with YAML-compatible configuration format.
class YAMLCompatResizeObservation(ResizeObservation):
    """
    YAML-compatible version of ResizeObservation wrapper.
    Accepts list format for shape specification.
    """

    def __init__(self, env: gym.Env, shape: list[int]):
        """
        Initialize YAMLCompatResizeObservation.

        Parameters:
        - env: Base environment to wrap
        - shape: Target shape as list [height, width]
        """

# Combined usage example:
import gymnasium as gym
from rl_zoo3.wrappers import (
    TruncatedOnSuccessWrapper,
    ActionNoiseWrapper,
    DelayedRewardWrapper,
    HistoryWrapper,
)

# Create base environment
env = gym.make("FetchReach-v1")

# Apply multiple wrappers (order matters)
env = TruncatedOnSuccessWrapper(env, reward_offset=1.0)
env = ActionNoiseWrapper(env, noise_std=0.05)
env = DelayedRewardWrapper(env, delay=3)
env = HistoryWrapper(env, horizon=4)

# Use wrapped environment
obs, info = env.reset()
for step in range(1000):
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        obs, info = env.reset()

from rl_zoo3.utils import get_wrapper_class
# Configuration dict (typically from hyperparameters file)
hyperparams = {
    "env_wrapper": [
        {
            "rl_zoo3.wrappers:TruncatedOnSuccessWrapper": {
                "reward_offset": 1.0,
                "n_successes": 2,
            }
        },
        {
            "rl_zoo3.wrappers:ActionNoiseWrapper": {
                "noise_std": 0.1,
            }
        },
    ]
}

# Get wrapper function from configuration
wrapper_fn = get_wrapper_class(hyperparams)

# Apply wrappers to environment
env = gym.make("FetchReach-v1")
if wrapper_fn is not None:
    env = wrapper_fn(env)

from rl_zoo3.exp_manager import ExperimentManager
import argparse

# Wrappers are automatically applied based on hyperparameters
args = argparse.Namespace(
    algo='sac',
    env='Pendulum-v1',
    n_timesteps=50000,
)

# Hyperparameters with wrapper specifications
hyperparams = {
    "env_wrapper": "rl_zoo3.wrappers:ActionSmoothingWrapper",
    "env_wrapper_kwargs": {"smoothing_coef": 0.8},
}

exp_manager = ExperimentManager(
    args=args,
    algo='sac',
    env_id='Pendulum-v1',
    log_folder='./logs',
    hyperparams=hyperparams,
)

# Wrappers applied automatically during environment creation
model = exp_manager.setup_experiment()

# Install with Tessl CLI:
#   npx tessl i tessl/pypi-rl-zoo3