# Environment Wrappers

Environment wrappers modify the behavior of existing environments without changing the underlying implementation. Gymnasium provides a comprehensive set of pre-built wrappers for common transformations, including observation processing, action modification, reward shaping, and rendering enhancements.

## Capabilities

### Common Wrappers

Fundamental wrappers for basic environment modifications.

```python { .api }
class TimeLimit(Wrapper):
    """
    Add a time limit to episodes.

    Args:
        env: Environment to wrap
        max_episode_steps: Maximum steps per episode
    """

    def __init__(self, env: gym.Env, max_episode_steps: int):
        pass

class Autoreset(Wrapper):
    """
    Automatically reset the environment when an episode ends.

    Args:
        env: Environment to wrap
    """

    def __init__(self, env: gym.Env):
        pass

class RecordEpisodeStatistics(Wrapper):
    """
    Record episode statistics (length, reward, time).

    Args:
        env: Environment to wrap
        buffer_length: Size of statistics buffers (default: 100)
        stats_key: Key for storing episode statistics in the info dict
    """

    def __init__(self, env: gym.Env, buffer_length: int = 100,
                 stats_key: str = "episode"):
        pass

class OrderEnforcing(Wrapper):
    """
    Enforce that reset is called before step.

    Args:
        env: Environment to wrap
        disable_render_order_enforcing: Disable render order enforcement
    """

    def __init__(self, env: gym.Env, disable_render_order_enforcing: bool = False):
        pass

class PassiveEnvChecker(Wrapper):
    """
    Check environment compliance with the Gymnasium API.

    Args:
        env: Environment to wrap
    """

    def __init__(self, env: gym.Env):
        pass
```

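As a quick illustration, `RecordEpisodeStatistics` surfaces per-episode return and length through the `info` dict under its `stats_key` (a minimal sketch using the wrapper's defaults):

```python
import gymnasium as gym
from gymnasium.wrappers import RecordEpisodeStatistics

env = gym.make("CartPole-v1")
env = RecordEpisodeStatistics(env, buffer_length=100)

obs, info = env.reset(seed=42)
for _ in range(1000):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    if terminated or truncated:
        # At episode end the wrapper adds cumulative reward "r",
        # length "l", and elapsed time "t" under info["episode"]
        print(info["episode"]["r"], info["episode"]["l"])
        obs, info = env.reset()
```
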
### Observation Wrappers

Wrappers that transform or modify observations.

```python { .api }
class FlattenObservation(ObservationWrapper):
    """
    Flatten the observation space (useful for Dict/Tuple spaces).

    Args:
        env: Environment to wrap
    """

    def __init__(self, env: gym.Env):
        pass

class FilterObservation(ObservationWrapper):
    """
    Filter keys from a Dict observation space.

    Args:
        env: Environment to wrap
        filter_keys: Keys to keep in the observation dict
    """

    def __init__(self, env: gym.Env, filter_keys: Sequence[str | int]):
        pass

class TransformObservation(ObservationWrapper):
    """
    Apply a custom transformation to observations.

    Args:
        env: Environment to wrap
        func: Function to transform observations
        observation_space: New observation space after transformation
    """

    def __init__(self, env: gym.Env, func: Callable,
                 observation_space: gym.Space | None):
        pass

class DtypeObservation(ObservationWrapper):
    """
    Convert the observation data type.

    Args:
        env: Environment to wrap
        dtype: Target data type
    """

    def __init__(self, env: gym.Env, dtype: Any):
        pass

class ReshapeObservation(ObservationWrapper):
    """
    Reshape observation arrays.

    Args:
        env: Environment to wrap
        shape: New shape for observations
    """

    def __init__(self, env: gym.Env, shape: int | tuple[int, ...]):
        pass

class ResizeObservation(ObservationWrapper):
    """
    Resize image observations.

    Args:
        env: Environment to wrap
        shape: New image shape (height, width)
    """

    def __init__(self, env: gym.Env, shape: tuple[int, int]):
        pass

class GrayscaleObservation(ObservationWrapper):
    """
    Convert RGB observations to grayscale.

    Args:
        env: Environment to wrap
        keep_dim: Whether to keep the color dimension
    """

    def __init__(self, env: gym.Env, keep_dim: bool = False):
        pass

class RescaleObservation(ObservationWrapper):
    """
    Rescale observation values to a target range.

    Args:
        env: Environment to wrap
        min_obs: Minimum observation value
        max_obs: Maximum observation value
    """

    def __init__(self, env: gym.Env,
                 min_obs: np.floating | np.integer | np.ndarray,
                 max_obs: np.floating | np.integer | np.ndarray):
        pass
```

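`TransformObservation` pairs an arbitrary callable with the observation space it produces. A minimal sketch, assuming a noise transformation that leaves the space unchanged:

```python
import gymnasium as gym
import numpy as np
from gymnasium.wrappers import TransformObservation

env = gym.make("CartPole-v1")

# Add small Gaussian noise; the shape is unchanged, so the original
# observation space can be passed through unmodified.
env = TransformObservation(
    env,
    func=lambda obs: (obs + np.random.normal(0, 0.01, size=obs.shape)).astype(np.float32),
    observation_space=env.observation_space,
)
```
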
### Stateful Observation Wrappers

Observation wrappers that maintain internal state.

```python { .api }
class FrameStackObservation(ObservationWrapper):
    """
    Stack multiple consecutive frames.

    Args:
        env: Environment to wrap
        stack_size: Number of frames to stack
        padding_type: Padding type for initial frames ('zero' or 'reset')
    """

    def __init__(self, env: gym.Env, stack_size: int,
                 padding_type: str = "zero"):
        pass

class DelayObservation(ObservationWrapper):
    """
    Add a delay to observations.

    Args:
        env: Environment to wrap
        delay: Number of steps to delay observations
    """

    def __init__(self, env: gym.Env, delay: int):
        pass

class NormalizeObservation(ObservationWrapper):
    """
    Online normalization of observations.

    Args:
        env: Environment to wrap
        epsilon: Small constant to avoid division by zero
    """

    def __init__(self, env: gym.Env, epsilon: float = 1e-8):
        pass

class TimeAwareObservation(ObservationWrapper):
    """
    Add time information to observations.

    Args:
        env: Environment to wrap
        flatten: Whether to flatten time into the observation
    """

    def __init__(self, env: gym.Env, flatten: bool = True):
        pass

class MaxAndSkipObservation(ObservationWrapper):
    """
    Max pooling and frame skipping for Atari-style games.

    Args:
        env: Environment to wrap
        skip: Number of frames to skip
    """

    def __init__(self, env: gym.Env, skip: int = 4):
        pass
```

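For instance, with the default `flatten=True`, `TimeAwareObservation` appends the time step as one extra feature (a quick sketch assuming CartPole's 4-dimensional observation):

```python
import gymnasium as gym
from gymnasium.wrappers import TimeAwareObservation

env = gym.make("CartPole-v1")
env = TimeAwareObservation(env)

obs, info = env.reset()
print(obs.shape)  # (5,): CartPole's 4 features plus the time step
```
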
### Action Wrappers

Wrappers that transform or constrain actions.

```python { .api }
class ClipAction(ActionWrapper):
    """
    Clip actions to the valid range for Box action spaces.

    Args:
        env: Environment to wrap
    """

    def __init__(self, env: gym.Env):
        pass

class RescaleAction(ActionWrapper):
    """
    Rescale actions from one range to another.

    Args:
        env: Environment to wrap
        min_action: Minimum action value in the new range
        max_action: Maximum action value in the new range
    """

    def __init__(self, env: gym.Env,
                 min_action: np.floating | np.integer | np.ndarray,
                 max_action: np.floating | np.integer | np.ndarray):
        pass

class TransformAction(ActionWrapper):
    """
    Apply a custom transformation to actions.

    Args:
        env: Environment to wrap
        func: Function to transform actions
        action_space: New action space after transformation
    """

    def __init__(self, env: gym.Env, func: Callable,
                 action_space: Space | None):
        pass

class StickyAction(ActionWrapper):
    """
    Repeat the previous action with some probability.

    Args:
        env: Environment to wrap
        repeat_action_probability: Probability of repeating the action
        repeat_action_duration: Duration or range for repeating actions
    """

    def __init__(self, env: gym.Env, repeat_action_probability: float,
                 repeat_action_duration: int | tuple[int, int] = 1):
        pass
```

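A common pattern is rescaling so a policy can always emit actions in `[-1, 1]` regardless of the environment's native range. A minimal sketch using Pendulum's `[-2, 2]` action range:

```python
import gymnasium as gym
from gymnasium.wrappers import RescaleAction

env = gym.make("Pendulum-v1")
# The wrapper maps [-1, 1] inputs back onto the native [-2, 2] range
env = RescaleAction(env, min_action=-1.0, max_action=1.0)
print(env.action_space)  # Box(-1.0, 1.0, (1,), float32)
```
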
### Reward Wrappers

Wrappers that modify reward signals.

```python { .api }
class ClipReward(RewardWrapper):
    """
    Clip rewards to a specified range.

    Args:
        env: Environment to wrap
        min_reward: Minimum reward value
        max_reward: Maximum reward value
    """

    def __init__(self, env: gym.Env, min_reward: float, max_reward: float):
        pass

class TransformReward(RewardWrapper):
    """
    Apply a custom transformation to rewards.

    Args:
        env: Environment to wrap
        func: Function to transform rewards
    """

    def __init__(self, env: gym.Env, func: Callable):
        pass

class NormalizeReward(RewardWrapper):
    """
    Online normalization of rewards.

    Args:
        env: Environment to wrap
        gamma: Discount factor for reward normalization
        epsilon: Small constant to avoid division by zero
    """

    def __init__(self, env: gym.Env, gamma: float = 0.99, epsilon: float = 1e-8):
        pass
```

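As an illustration, `TransformReward` can implement sign-based reward clipping (a technique popularized by Atari training); any scalar-to-scalar callable works:

```python
import gymnasium as gym
import numpy as np
from gymnasium.wrappers import TransformReward

env = gym.make("CartPole-v1")
# Map every reward to its sign, bounding the reward scale to {-1, 0, 1}
env = TransformReward(env, lambda r: float(np.sign(r)))
```
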
### Rendering Wrappers

Wrappers for modifying environment rendering.

```python { .api }
class RecordVideo(Wrapper):
    """
    Record environment episodes as video files.

    Args:
        env: Environment to wrap
        video_folder: Directory to save videos
        episode_trigger: Function determining which episodes to record
        step_trigger: Function determining which steps to record
        video_length: Maximum video length in steps
        name_prefix: Prefix for video filenames
        fps: Frames per second for video recording
        disable_logger: Whether to disable logging
        gc_trigger: Function determining when to collect garbage
    """

    def __init__(self, env: gym.Env, video_folder: str,
                 episode_trigger: Callable[[int], bool] | None = None,
                 step_trigger: Callable[[int], bool] | None = None,
                 video_length: int = 0, name_prefix: str = "rl-video",
                 fps: int | None = None, disable_logger: bool = True,
                 gc_trigger: Callable[[int], bool] | None = lambda episode: True):
        pass

class HumanRendering(Wrapper):
    """
    Enable human rendering mode for environments.

    Args:
        env: Environment to wrap
    """

    def __init__(self, env: gym.Env):
        pass

class RenderCollection(Wrapper):
    """
    Collect rendered frames into a list that is returned by render().

    Args:
        env: Environment to wrap
        pop_frames: Whether to clear the frame list after render
        reset_clean: Whether to clear the frame list on reset
    """

    def __init__(self, env: gym.Env, pop_frames: bool = True,
                 reset_clean: bool = True):
        pass
```

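A minimal `RecordVideo` sketch; the folder name and trigger are illustrative, and the environment must be created with `render_mode="rgb_array"` so frames are available:

```python
import gymnasium as gym
from gymnasium.wrappers import RecordVideo

env = gym.make("CartPole-v1", render_mode="rgb_array")
env = RecordVideo(
    env,
    video_folder="videos",                      # hypothetical output directory
    episode_trigger=lambda ep: ep % 100 == 0,   # record every 100th episode
    name_prefix="cartpole",
)
```
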
### Array Conversion Wrappers

Wrappers for converting between different array libraries (lazy-loaded).

```python { .api }
class NumpyToTorch(Wrapper):
    """
    Convert NumPy arrays to PyTorch tensors.

    Args:
        env: Environment to wrap
        device: PyTorch device for tensors
    """

    def __init__(self, env: gym.Env, device=None):
        pass

class JaxToNumpy(Wrapper):
    """
    Convert JAX arrays to NumPy arrays.

    Args:
        env: Environment to wrap
    """

    def __init__(self, env: gym.Env):
        pass

class JaxToTorch(Wrapper):
    """
    Convert JAX arrays to PyTorch tensors.

    Args:
        env: Environment to wrap
        device: PyTorch device for tensors
    """

    def __init__(self, env: gym.Env, device=None):
        pass
```

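A minimal `NumpyToTorch` sketch, assuming PyTorch is installed; observations come back as tensors, and tensor actions are converted for the underlying NumPy-based environment:

```python
import gymnasium as gym
import torch
from gymnasium.wrappers import NumpyToTorch

env = gym.make("CartPole-v1")
env = NumpyToTorch(env, device="cpu")

obs, info = env.reset()
print(type(obs))  # <class 'torch.Tensor'>

# Tensor actions are converted back to NumPy for the base environment
obs, reward, terminated, truncated, info = env.step(torch.tensor(0))
```
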
## Usage Examples

### Basic Wrapper Usage

```python
import gymnasium as gym
from gymnasium.wrappers import (
    TimeLimit, FlattenObservation, ClipAction, RecordEpisodeStatistics
)

# Create base environment
env = gym.make('LunarLander-v3')

# Add time limit
env = TimeLimit(env, max_episode_steps=500)

# Flatten observations if needed
env = FlattenObservation(env)

# Chain multiple wrappers
env = gym.make('BipedalWalker-v3')
env = ClipAction(env)               # Ensure actions are in the valid range
env = RecordEpisodeStatistics(env)  # Track episode stats
env = TimeLimit(env, max_episode_steps=1000)
```

### Observation Processing Pipeline

```python
import gymnasium as gym
from gymnasium.wrappers import (
    ResizeObservation, GrayscaleObservation,
    FrameStackObservation, NormalizeObservation
)

# Create Atari environment with a preprocessing pipeline
# (ALE environments require the ale-py package)
env = gym.make('ALE/Breakout-v5', render_mode='rgb_array')

# Resize to a smaller resolution
env = ResizeObservation(env, (84, 84))

# Convert to grayscale
env = GrayscaleObservation(env, keep_dim=True)

# Stack 4 frames for temporal information
env = FrameStackObservation(env, stack_size=4)

# Normalize observations online
env = NormalizeObservation(env)
```

### Custom Wrapper Creation

```python
import gymnasium as gym
import numpy as np

class RewardScalingWrapper(gym.RewardWrapper):
    """Scale rewards by a constant factor."""

    def __init__(self, env, scale=0.1):
        super().__init__(env)
        self.scale = scale

    def reward(self, reward):
        return reward * self.scale

class NoopResetWrapper(gym.Wrapper):
    """Take a random number of no-op actions at episode start."""

    def __init__(self, env, noop_max=30):
        super().__init__(env)
        self.noop_max = noop_max
        self.noop_action = 0

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)

        # Execute a random number of no-op actions
        noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
        for _ in range(noops):
            obs, _, terminated, truncated, info = self.env.step(self.noop_action)
            if terminated or truncated:
                obs, info = self.env.reset(**kwargs)

        return obs, info

# Usage
env = gym.make('ALE/Breakout-v5')
env = NoopResetWrapper(env, noop_max=10)
env = RewardScalingWrapper(env, scale=0.1)
```

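The same pattern works for observations: an `ObservationWrapper` subclass only needs to override `observation()`. A minimal sketch with a hypothetical noise parameter, continuing the imports above:

```python
class ObservationNoiseWrapper(gym.ObservationWrapper):
    """Add Gaussian noise to observations (illustrative only)."""

    def __init__(self, env, stddev=0.01):
        super().__init__(env)
        self.stddev = stddev

    def observation(self, observation):
        # Called on every observation returned by reset() and step()
        noise = np.random.normal(0.0, self.stddev, size=observation.shape)
        return (observation + noise).astype(observation.dtype)
```
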
### Wrapper Utilities

```python
# Check the wrapper hierarchy
env = gym.make('CartPole-v1')
env = TimeLimit(env, max_episode_steps=200)
env = RecordEpisodeStatistics(env)

print(env)            # Shows the wrapper stack
print(env.unwrapped)  # Access the original environment

# Access wrapped environment attributes
print(env.unwrapped.spec.max_episode_steps)

# Remove specific wrapper types
def remove_wrapper(env, wrapper_class):
    """Remove a specific wrapper from the wrapper stack."""
    if isinstance(env, wrapper_class):
        return env.env
    elif hasattr(env, 'env'):
        env.env = remove_wrapper(env.env, wrapper_class)
    # Return the (possibly unwrapped) environment; also handles the
    # base-environment case, which previously fell through to None
    return env

env_without_timelimit = remove_wrapper(env, TimeLimit)
```