# Environment Composition

Framework for programmatically building complex reinforcement learning environments by combining entities, arenas, and tasks. Enables modular environment design with reusable components, flexible composition patterns, and comprehensive observation systems.

## Capabilities

### Environment Building

Core environment class for composer-based RL environments.
```python { .api }
class Environment:
    """
    Composer environment for custom RL tasks.

    Provides full RL environment interface with modular composition
    of entities, arenas, and tasks.
    """

    def __init__(self, task: 'Task', arena: 'Arena' = None,
                 time_limit: float = float('inf'),
                 random_state: np.random.RandomState = None):
        """
        Initialize composer environment.

        Parameters:
        - task: Task instance defining objectives and rewards
        - arena: Optional arena for environment layout (default: task arena)
        - time_limit: Episode time limit in seconds (default: infinite)
        - random_state: Random state for reproducibility
        """

    def reset(self) -> 'TimeStep':
        """
        Reset environment and return initial timestep.

        Returns:
        Initial TimeStep with observations
        """

    def step(self, action) -> 'TimeStep':
        """
        Apply action and advance environment.

        Parameters:
        - action: Action conforming to action_spec()

        Returns:
        TimeStep with new observations and rewards
        """

    def action_spec(self) -> 'BoundedArraySpec':
        """
        Get action specification.

        Returns:
        Specification describing valid actions
        """

    def observation_spec(self) -> dict:
        """
        Get observation specification.

        Returns:
        Dict mapping observation names to specs
        """

class EpisodeInitializationError(Exception):
    """Error raised during episode initialization."""

# Environment hooks
HOOK_NAMES: tuple
"""Names of available environment hooks for customization."""

class ObservationPadding:
    """Utilities for padding observations to consistent shapes."""
```
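The `reset()`/`step()` interface above follows the dm_env-style TimeStep protocol. The following self-contained sketch shows the standard interaction loop; `StubEnvironment`, its observation contents, and the trivial policy are illustrative assumptions standing in for a real composer `Environment`:

```python
from collections import namedtuple
import numpy as np

# Minimal stand-in for the TimeStep returned by reset()/step().
TimeStep = namedtuple('TimeStep', ['step_type', 'reward', 'discount', 'observation'])

class StubEnvironment:
    """Illustrative stub with the same reset/step surface as Environment."""

    def __init__(self, time_limit_steps=5):
        self._time_limit_steps = time_limit_steps
        self._step_count = 0

    def reset(self):
        # First timestep of an episode carries no reward.
        self._step_count = 0
        return TimeStep('FIRST', None, 1.0, {'position': np.zeros(3)})

    def step(self, action):
        self._step_count += 1
        done = self._step_count >= self._time_limit_steps
        return TimeStep('LAST' if done else 'MID',
                        reward=float(-np.linalg.norm(action)),
                        discount=0.0 if done else 1.0,
                        observation={'position': np.asarray(action)})

# Standard interaction loop: reset once, then step until termination.
env = StubEnvironment()
time_step = env.reset()
total_reward = 0.0
while time_step.step_type != 'LAST':
    action = np.zeros(3)  # a trivial "policy" for illustration
    time_step = env.step(action)
    total_reward += time_step.reward
```

The same loop shape applies unchanged to a real `composer.Environment`, with `action` drawn from `action_spec()` instead.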

### Entity System

Base classes for all environment entities with observables and physics integration.

```python { .api }
class Entity:
    """
    Base class for all environment entities.

    Entities represent physical objects, agents, or abstract components
    that can be composed into environments.
    """

    def initialize_episode(self, physics: 'Physics', random_state: np.random.RandomState) -> None:
        """
        Initialize entity for new episode.

        Parameters:
        - physics: Physics instance for the episode
        - random_state: Random state for stochastic initialization
        """

    def before_step(self, physics: 'Physics', action, random_state: np.random.RandomState) -> None:
        """
        Called before physics step.

        Parameters:
        - physics: Current physics state
        - action: Action being applied
        - random_state: Random state
        """

    def after_step(self, physics: 'Physics', random_state: np.random.RandomState) -> None:
        """
        Called after physics step.

        Parameters:
        - physics: Updated physics state
        - random_state: Random state
        """

    @property
    def mjcf_model(self) -> 'RootElement':
        """MJCF model for this entity."""

    @property
    def observables(self) -> 'Observables':
        """Observable quantities for this entity."""

class ModelWrapperEntity(Entity):
    """
    Entity that wraps an existing MJCF model.

    Provides Entity interface for pre-existing MJCF models.
    """

    def __init__(self, mjcf_model: 'RootElement'):
        """
        Initialize with MJCF model.

        Parameters:
        - mjcf_model: MJCF model to wrap
        """

class FreePropObservableMixin:
    """Mixin for entities with free-floating observables."""

class Robot(Entity):
    """
    Base class for robotic entities.

    Specialized entity for robots with actuation, sensing,
    and control interfaces.
    """

    @property
    def actuators(self) -> list:
        """List of actuator elements."""

    @property
    def joints(self) -> list:
        """List of joint elements."""
```

### Arena System

Base classes for environment layouts and spatial organization.

```python { .api }
class Arena(Entity):
    """
    Base class for environment arenas.

    Arenas define the spatial layout and structure of environments,
    providing surfaces, boundaries, and spatial organization.
    """

    @property
    def ground_geoms(self) -> list:
        """Ground geometry elements."""

    def regenerate(self, random_state: np.random.RandomState) -> None:
        """
        Regenerate arena layout.

        Parameters:
        - random_state: Random state for stochastic generation
        """

    def add_entity(self, entity: 'Entity', attachment_frame: 'Element' = None) -> None:
        """
        Add entity to arena.

        Parameters:
        - entity: Entity to add
        - attachment_frame: Optional attachment point
        """
```
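Conceptually, `add_entity` splices an entity's model into the arena's model tree at an attachment point. A self-contained sketch of that composition pattern, using stub classes in place of the real `Entity`/`Arena` (the stub names and attachment bookkeeping are illustrative assumptions, not the composer implementation):

```python
class StubEntity:
    """Illustrative entity: just a named model fragment."""

    def __init__(self, name):
        self.name = name

class StubArena(StubEntity):
    """Illustrative arena: records attached entities, mirroring Arena.add_entity."""

    def __init__(self):
        super().__init__('arena')
        self.attached = []

    def add_entity(self, entity, attachment_frame=None):
        # Record the entity and where it was attached; the real Arena
        # instead merges the entity's MJCF model into its own worldbody.
        self.attached.append((entity, attachment_frame or 'worldbody'))

arena = StubArena()
arena.add_entity(StubEntity('robot'))
arena.add_entity(StubEntity('ball'), attachment_frame='table_top')

attachments = [(entity.name, frame) for entity, frame in arena.attached]
```

With the real classes, the entity's geoms, joints, and observables become part of the composed environment once attached.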

### Task System

Base classes for defining RL objectives and reward functions.

```python { .api }
class Task:
    """
    Base class for RL tasks.

    Tasks define objectives, reward functions, termination conditions,
    and episode initialization for RL environments.
    """

    def initialize_episode(self, physics: 'Physics', random_state: np.random.RandomState) -> None:
        """
        Initialize task for new episode.

        Parameters:
        - physics: Physics instance
        - random_state: Random state
        """

    def before_step(self, physics: 'Physics', action, random_state: np.random.RandomState) -> None:
        """
        Called before physics step.

        Parameters:
        - physics: Current physics state
        - action: Action being applied
        - random_state: Random state
        """

    def after_step(self, physics: 'Physics', random_state: np.random.RandomState) -> None:
        """
        Called after physics step.

        Parameters:
        - physics: Updated physics state
        - random_state: Random state
        """

    def get_reward(self, physics: 'Physics') -> float:
        """
        Calculate reward for current state.

        Parameters:
        - physics: Current physics state

        Returns:
        Scalar reward value
        """

    def get_termination(self, physics: 'Physics') -> bool:
        """
        Check if episode should terminate.

        Parameters:
        - physics: Current physics state

        Returns:
        True if episode should end
        """

    def get_discount(self, physics: 'Physics') -> float:
        """
        Get discount factor for current step.

        Parameters:
        - physics: Current physics state

        Returns:
        Discount factor (typically 1.0 or 0.0)
        """

    @property
    def observables(self) -> 'Observables':
        """Observable quantities for this task."""

    @property
    def control_timestep(self) -> float:
        """Control timestep duration."""

class NullTask(Task):
    """Task with no objectives - useful for free exploration."""
```

### Observable System

System for defining and managing observable quantities.

```python { .api }
class Observables:
    """
    Collection of observable quantities with specifications.

    Manages named observables with automatic specification generation
    and value extraction from physics.
    """

    def add_observable(self, name: str, observable_callable: callable) -> None:
        """
        Add named observable.

        Parameters:
        - name: Observable name
        - observable_callable: Function returning observable value
        """

    def get_observation(self, physics: 'Physics') -> dict:
        """
        Extract all observable values.

        Parameters:
        - physics: Physics instance

        Returns:
        Dict mapping observable names to values
        """

def observable(func: callable) -> callable:
    """
    Decorator for marking methods as observable.

    Parameters:
    - func: Method to mark as observable

    Returns:
    Decorated method with observable metadata

    Example:
    >>> @observable
    ... def joint_positions(self, physics):
    ...     return physics.named.data.qpos[self.joints]
    """

def cached_property(func: callable) -> property:
    """
    Decorator for cached property computation.

    Parameters:
    - func: Method to cache

    Returns:
    Property that caches result after first access

    Example:
    >>> @cached_property
    ... def joint_names(self):
    ...     return [joint.name for joint in self.joints]
    """
```

### Initialization System

Base classes for entity initialization strategies.

```python { .api }
class Initializer:
    """
    Base class for initialization strategies.

    Initializers define how entities should be positioned and configured
    at the start of each episode.
    """

    def __call__(self, physics: 'Physics', random_state: np.random.RandomState, entity: 'Entity') -> None:
        """
        Initialize entity in physics.

        Parameters:
        - physics: Physics instance
        - random_state: Random state
        - entity: Entity to initialize
        """
```
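A self-contained sketch of the `Initializer` protocol: a pose initializer that samples a position from the given `random_state` and applies it to an entity. The stub entity, its `set_pose` method, and the box-sampling strategy are illustrative assumptions; a real initializer would write into `physics`:

```python
import numpy as np

class StubEntity:
    """Illustrative entity with a settable pose."""

    def __init__(self):
        self.position = None

    def set_pose(self, physics, position):
        self.position = np.asarray(position)

class UniformPositionInitializer:
    """Places an entity uniformly at random within an axis-aligned box."""

    def __init__(self, lower, upper):
        self.lower = np.asarray(lower, dtype=float)
        self.upper = np.asarray(upper, dtype=float)

    def __call__(self, physics, random_state, entity):
        # Same signature as Initializer.__call__ above; all randomness
        # flows through random_state so episodes are reproducible.
        position = random_state.uniform(self.lower, self.upper)
        entity.set_pose(physics, position)

initializer = UniformPositionInitializer(lower=[-1, -1, 0], upper=[1, 1, 0.5])
entity = StubEntity()
initializer(physics=None, random_state=np.random.RandomState(0), entity=entity)
```

Because the initializer is just a callable, tasks can hold a list of them and invoke each one from `initialize_episode`.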

## Usage Examples

### Creating Custom Environments

```python
from dm_control import composer
from dm_control import mjcf
import numpy as np

# Create custom task
class ReachTask(composer.Task):
    def __init__(self, target_position):
        super().__init__()
        self.target_position = target_position

    def initialize_episode(self, physics, random_state):
        # Randomize target position at the start of each episode
        self.target_position = random_state.uniform(-1, 1, size=3)

    def get_reward(self, physics):
        # Simple distance-based reward
        hand_pos = physics.named.data.site_xpos['hand_site']
        distance = np.linalg.norm(hand_pos - self.target_position)
        return np.exp(-distance)

    @composer.observable
    def target_position_obs(self, physics):
        return self.target_position

# Create custom arena
class SimpleArena(composer.Arena):
    def _build(self):
        self.mjcf_model.worldbody.add(
            'geom', type='plane', size=[2, 2, 0.1], rgba=[0.5, 0.5, 0.5, 1])

# Create environment
arena = SimpleArena()
task = ReachTask(target_position=[0.5, 0.5, 0.5])
env = composer.Environment(task=task, arena=arena, time_limit=10.0)
```

### Entity Composition

```python
# Load robot entity
robot_model = mjcf.from_path('/path/to/robot.xml')
robot = composer.ModelWrapperEntity(robot_model)

# Create observable for joint positions.
# ModelWrapperEntity does not expose a `joints` property (only Robot does),
# so find the joints in the wrapped MJCF model instead.
joints = robot.mjcf_model.find_all('joint')

@composer.observable
def joint_positions(physics):
    return physics.named.data.qpos[[joint.name for joint in joints]]

# Add observable to robot
robot.observables.add_observable('joint_pos', joint_positions)

# Create custom entity
class Ball(composer.Entity):
    def _build(self):
        self.mjcf_model.worldbody.add('body', name='ball').add(
            'geom', type='sphere', size=[0.05], rgba=[1, 0, 0, 1])

    @composer.observable
    def position(self, physics):
        return physics.named.data.xpos['ball']

ball = Ball()
```

### Advanced Task Design

```python
class MultiObjectiveTask(composer.Task):
    def __init__(self, robots, targets):
        super().__init__()
        self.robots = robots
        self.targets = targets
        self.weights = [1.0, 0.5, 0.2]  # Weights for the three objectives

    def get_reward(self, physics):
        # Primary objective: reach targets (averaged over robots, so the
        # number of reward terms matches the number of weights)
        reach_rewards = []
        for robot, target in zip(self.robots, self.targets):
            hand_pos = physics.named.data.site_xpos[f'{robot.name}_hand']
            distance = np.linalg.norm(hand_pos - target)
            reach_rewards.append(np.exp(-distance))
        rewards = [np.mean(reach_rewards)]

        # Secondary objective: energy efficiency
        control_cost = np.sum(physics.data.ctrl ** 2)
        rewards.append(-0.1 * control_cost)

        # Tertiary objective: smoothness
        velocity_cost = np.sum(physics.data.qvel ** 2)
        rewards.append(-0.01 * velocity_cost)

        return np.dot(rewards, self.weights)

    @composer.observable
    def reward_value(self, physics):
        # Expose the combined reward for analysis
        return np.array([self.get_reward(physics)])
```

### Observable Management

```python
class SensorEntity(composer.Entity):
    def _build(self):
        # Add a site for the sensors, then the sensors themselves
        self.mjcf_model.worldbody.add('site', name='sensor_site')
        self.mjcf_model.sensor.add('accelerometer',
                                   name='accel', site='sensor_site')
        self.mjcf_model.sensor.add('gyro',
                                   name='gyro', site='sensor_site')

    @composer.observable
    def acceleration(self, physics):
        return physics.named.data.sensordata['accel']

    @composer.observable
    def angular_velocity(self, physics):
        return physics.named.data.sensordata['gyro']

    @composer.cached_property
    def sensor_site(self):
        return self.mjcf_model.find('site', 'sensor_site')

# Use in environment: the entity must be attached to the arena
# so that its observables appear in the observation dict
arena = composer.Arena()
arena.add_entity(SensorEntity())
env = composer.Environment(task=composer.NullTask(), arena=arena)

# Access observations
time_step = env.reset()
accel_obs = time_step.observation['acceleration']
gyro_obs = time_step.observation['angular_velocity']
```