A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym).
—
Environment wrappers modify the behavior of existing environments without changing the underlying implementation. Gymnasium provides a comprehensive set of pre-built wrappers for common transformations including observation processing, action modification, reward shaping, and rendering enhancements.
Fundamental wrappers for basic environment modifications.
class TimeLimit(Wrapper):
    """Impose a maximum number of steps on each episode.

    Args:
        env: The environment to wrap.
        max_episode_steps: Maximum number of steps allowed per episode.
    """

    def __init__(self, env: gym.Env, max_episode_steps: int):
        pass
class Autoreset(Wrapper):
    """Reset the environment automatically once an episode finishes.

    Args:
        env: The environment to wrap.
    """

    def __init__(self, env: gym.Env):
        pass
class RecordEpisodeStatistics(Wrapper):
    """Track per-episode statistics such as length, return, and elapsed time.

    Args:
        env: The environment to wrap.
        buffer_length: Capacity of the internal statistics buffers (default: 100).
        stats_key: Key under which episode statistics appear in the info dict.
    """

    def __init__(self, env: gym.Env, buffer_length: int = 100,
                 stats_key: str = "episode"):
        pass
class OrderEnforcing(Wrapper):
    """Require that ``reset`` is called before ``step``.

    Args:
        env: The environment to wrap.
        disable_render_order_enforcing: If True, do not enforce ordering for render.
    """

    def __init__(self, env: gym.Env, disable_render_order_enforcing: bool = False):
        pass
class PassiveEnvChecker(Wrapper):
"""
Check environment compliance with Gymnasium API.
Args:
env: Environment to wrap
"""
def __init__(self, env: gym.Env):
passWrappers that transform or modify observations.
class FlattenObservation(ObservationWrapper):
    """Flatten the observation space (useful for Dict/Tuple spaces).

    Args:
        env: The environment to wrap.
    """

    def __init__(self, env: gym.Env):
        pass
class FilterObservation(ObservationWrapper):
    """Keep only selected keys from a Dict observation space.

    Args:
        env: The environment to wrap.
        filter_keys: The keys to retain in the observation dict.
    """

    def __init__(self, env: gym.Env, filter_keys: Sequence[str | int]):
        pass
class TransformObservation(ObservationWrapper):
    """Apply a user-supplied function to every observation.

    Args:
        env: The environment to wrap.
        func: Callable applied to each observation.
        observation_space: The observation space produced by ``func``.
    """

    def __init__(self, env: gym.Env, func: Callable,
                 observation_space: gym.Space | None):
        pass
class DtypeObservation(ObservationWrapper):
    """Cast observations to a different data type.

    Args:
        env: The environment to wrap.
        dtype: The target data type.
    """

    def __init__(self, env: gym.Env, dtype: Any):
        pass
class ReshapeObservation(ObservationWrapper):
    """Reshape array observations to a new shape.

    Args:
        env: The environment to wrap.
        shape: The target shape for observations.
    """

    def __init__(self, env: gym.Env, shape: int | tuple[int, ...]):
        pass
class ResizeObservation(ObservationWrapper):
    """Resize image observations to a new resolution.

    Args:
        env: The environment to wrap.
        shape: Target image shape as (height, width).
    """

    def __init__(self, env: gym.Env, shape: tuple[int, int]):
        pass
class GrayscaleObservation(ObservationWrapper):
    """Convert RGB image observations to grayscale.

    Args:
        env: The environment to wrap.
        keep_dim: Whether to retain the (single) color channel dimension.
    """

    def __init__(self, env: gym.Env, keep_dim: bool = False):
        pass
class RescaleObservation(ObservationWrapper):
"""
Rescale observation values to a target range.
Args:
env: Environment to wrap
min_obs: Minimum observation value
max_obs: Maximum observation value
"""
def __init__(self, env: gym.Env,
min_obs: np.floating | np.integer | np.ndarray,
max_obs: np.floating | np.integer | np.ndarray):
passObservation wrappers that maintain internal state.
class FrameStackObservation(ObservationWrapper):
    """Stack the most recent consecutive frames into one observation.

    Args:
        env: The environment to wrap.
        stack_size: How many frames to stack.
        padding_type: How to pad before enough frames exist ('zero' or 'reset').
    """

    def __init__(self, env: gym.Env, stack_size: int,
                 padding_type: str = "zero"):
        pass
class DelayObservation(ObservationWrapper):
    """Delay observations by a fixed number of steps.

    Args:
        env: The environment to wrap.
        delay: Number of steps by which observations are delayed.
    """

    def __init__(self, env: gym.Env, delay: int):
        pass
class NormalizeObservation(ObservationWrapper):
    """Normalize observations using running statistics collected online.

    Args:
        env: The environment to wrap.
        epsilon: Small constant that guards against division by zero.
    """

    def __init__(self, env: gym.Env, epsilon: float = 1e-8):
        pass
class TimeAwareObservation(ObservationWrapper):
    """Augment observations with time information.

    Args:
        env: The environment to wrap.
        flatten: Whether the time value is flattened into the observation.
    """

    def __init__(self, env: gym.Env, flatten: bool = True):
        pass
class MaxAndSkipObservation(ObservationWrapper):
"""
Max pooling and frame skipping for Atari-style games.
Args:
env: Environment to wrap
skip: Number of frames to skip
"""
def __init__(self, env: gym.Env, skip: int = 4):
passWrappers that transform or constrain actions.
class ClipAction(ActionWrapper):
    """Clip actions into the valid range of a Box action space.

    Args:
        env: The environment to wrap.
    """

    def __init__(self, env: gym.Env):
        pass
class RescaleAction(ActionWrapper):
    """Linearly rescale actions from one range to another.

    Args:
        env: The environment to wrap.
        min_action: Lower bound of the new action range.
        max_action: Upper bound of the new action range.
    """

    def __init__(self, env: gym.Env,
                 min_action: np.floating | np.integer | np.ndarray,
                 max_action: np.floating | np.integer | np.ndarray):
        pass
class TransformAction(ActionWrapper):
    """Apply a user-supplied function to every action.

    Args:
        env: The environment to wrap.
        func: Callable applied to each action.
        action_space: The action space expected by the wrapped environment after transformation.
    """

    def __init__(self, env: gym.Env, func: Callable,
                 action_space: Space | None):
        pass
class StickyAction(ActionWrapper):
"""
Repeat previous action with some probability.
Args:
env: Environment to wrap
repeat_action_probability: Probability of repeating action
repeat_action_duration: Duration or range for repeating actions
"""
def __init__(self, env: gym.Env, repeat_action_probability: float,
repeat_action_duration: int | tuple[int, int] = 1):
passWrappers that modify reward signals.
class ClipReward(RewardWrapper):
    """Clip rewards into a specified interval.

    Args:
        env: The environment to wrap.
        min_reward: Lower clipping bound.
        max_reward: Upper clipping bound.
    """

    def __init__(self, env: gym.Env, min_reward: float, max_reward: float):
        pass
class TransformReward(RewardWrapper):
    """Apply a user-supplied function to every reward.

    Args:
        env: The environment to wrap.
        func: Callable applied to each reward.
    """

    def __init__(self, env: gym.Env, func: Callable):
        pass
class NormalizeReward(RewardWrapper):
"""
Online normalization of rewards.
Args:
env: Environment to wrap
gamma: Discount factor for reward normalization
epsilon: Small constant to avoid division by zero
"""
def __init__(self, env: gym.Env, gamma: float = 0.99, epsilon: float = 1e-8):
passWrappers for modifying environment rendering.
class RecordVideo(Wrapper):
    """Save environment episodes to video files.

    Args:
        env: The environment to wrap.
        video_folder: Directory where video files are written.
        episode_trigger: Predicate selecting which episodes to record.
        step_trigger: Predicate selecting which steps to record.
        video_length: Maximum video length in steps.
        name_prefix: Filename prefix for saved videos.
        fps: Frames per second used when recording.
        disable_logger: Whether logging is disabled.
        gc_trigger: Predicate deciding when garbage collection runs.
    """

    def __init__(self, env: gym.Env, video_folder: str,
                 episode_trigger: Callable[[int], bool] | None = None,
                 step_trigger: Callable[[int], bool] | None = None,
                 video_length: int = 0, name_prefix: str = "rl-video",
                 fps: int | None = None, disable_logger: bool = True,
                 gc_trigger: Callable[[int], bool] | None = lambda episode: True):
        pass
class HumanRendering(Wrapper):
    """Enable human rendering mode for an environment.

    Args:
        env: The environment to wrap.
    """

    def __init__(self, env: gym.Env):
        pass
class RenderCollection(Wrapper):
"""
Collect rendered frames from multiple render modes.
Args:
env: Environment to wrap
pop_frames: Whether to clear frames after render
reset_clean: Whether to clear frames on reset
"""
def __init__(self, env: gym.Env, pop_frames: bool = True,
reset_clean: bool = True):
passWrappers for converting between different array libraries (lazy-loaded).
class NumpyToTorch(Wrapper):
    """Convert NumPy arrays to PyTorch tensors.

    Args:
        env: The environment to wrap.
        device: PyTorch device on which tensors are placed.
    """

    def __init__(self, env: gym.Env, device=None):
        pass
class JaxToNumpy(Wrapper):
    """Convert JAX arrays to NumPy arrays.

    Args:
        env: The environment to wrap.
    """

    def __init__(self, env: gym.Env):
        pass
class JaxToTorch(Wrapper):
    """Convert JAX arrays to PyTorch tensors.

    Args:
        env: The environment to wrap.
        device: PyTorch device on which tensors are placed.
    """

    def __init__(self, env: gym.Env, device=None):
        pass

import gymnasium as gym
from gymnasium.wrappers import TimeLimit, FlattenObservation, ClipAction
# Create base environment
env = gym.make('LunarLander-v2')
# Add time limit
env = TimeLimit(env, max_episode_steps=500)
# Flatten observations if needed
env = FlattenObservation(env)
# Chain multiple wrappers
env = gym.make('BipedalWalker-v3')
env = ClipAction(env) # Ensure actions are in valid range
env = RecordEpisodeStatistics(env) # Track episode stats
env = TimeLimit(env, max_episode_steps=1000)from gymnasium.wrappers import (
ResizeObservation, GrayscaleObservation,
FrameStackObservation, NormalizeObservation
)
# Create Atari environment with preprocessing pipeline
env = gym.make('ALE/Breakout-v5', render_mode='rgb_array')
# Resize to smaller resolution
env = ResizeObservation(env, (84, 84))
# Convert to grayscale
env = GrayscaleObservation(env, keep_dim=True)
# Stack 4 frames for temporal information
env = FrameStackObservation(env, stack_size=4)
# Normalize observations online
env = NormalizeObservation(env)import numpy as np
class RewardScalingWrapper(gym.RewardWrapper):
    """Multiply every reward by a fixed scaling factor."""

    def __init__(self, env, scale=0.1):
        super().__init__(env)
        self.scale = scale

    def reward(self, reward):
        """Return the reward scaled by ``self.scale``."""
        return self.scale * reward
class NoopResetWrapper(gym.Wrapper):
    """Take a random number of no-op actions after each reset.

    Randomizing the start state this way is a common Atari trick.
    """

    def __init__(self, env, noop_max=30):
        super().__init__(env)
        self.noop_max = noop_max
        self.noop_action = 0

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        # Draw how many no-ops to perform (1..noop_max inclusive).
        num_noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
        for _ in range(num_noops):
            obs, _, terminated, truncated, info = self.env.step(self.noop_action)
            if terminated or truncated:
                # Episode ended during the no-ops; reset and keep going.
                obs, info = self.env.reset(**kwargs)
        return obs, info
# Usage: apply the custom wrappers defined above
env = gym.make('ALE/Breakout-v5')
env = NoopResetWrapper(env, noop_max=10)
env = RewardScalingWrapper(env, scale=0.1)

# Check wrapper hierarchy
# Build a small wrapper stack, then inspect it.
env = gym.make('CartPole-v1')
env = TimeLimit(env, max_episode_steps=200)
env = RecordEpisodeStatistics(env)
print(env) # Shows wrapper stack
print(env.unwrapped) # Access original environment
# Access wrapped environment attributes
# NOTE(review): assumes env.unwrapped.spec is not None — confirm for the chosen env id
print(env.unwrapped.spec.max_episode_steps)
# Remove specific wrapper types
def remove_wrapper(env, wrapper_class):
    """Remove the first wrapper of ``wrapper_class`` from the wrapper stack.

    Returns the stack with that wrapper unlinked; the original object is
    returned unchanged if no such wrapper is found.
    """
    # The target is the outermost layer: drop it by returning what it wraps.
    if isinstance(env, wrapper_class):
        return env.env
    # Otherwise recurse into the next layer, if one exists.
    if hasattr(env, 'env'):
        env.env = remove_wrapper(env.env, wrapper_class)
    return env
env_without_timelimit = remove_wrapper(env, TimeLimit)Install with Tessl CLI
npx tessl i tessl/pypi-gymnasium