or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

composer.mdindex.mdmjcf.mdphysics.mdsuite.mdviewer.md

composer.mddocs/

0

# Environment Composition

1

2

A framework for programmatically building complex reinforcement learning (RL) environments by combining entities, arenas, and tasks. It enables modular environment design through reusable components, flexible composition patterns, and a comprehensive observation system.

3

4

## Capabilities

5

6

### Environment Building

7

8

Core environment class for composer-based RL environments.

9

10

```python { .api }

11

class Environment:
    """Reinforcement-learning environment assembled from composer parts.

    An ``Environment`` wires a task (and optionally an arena) into the usual
    RL loop: ``reset`` starts an episode, ``step`` advances it, and the two
    ``*_spec`` methods describe the action and observation interfaces.
    """

    def __init__(self, task: 'Task', arena: 'Arena' = None,
                 time_limit: float = float('inf'),
                 random_state: np.random.RandomState = None):
        """Build the environment.

        Args:
          task: The ``Task`` that supplies objectives and rewards.
          arena: Optional ``Arena`` giving the spatial layout; when omitted,
            the task's own arena is used.
          time_limit: Maximum episode length in seconds; unbounded by default.
          random_state: Source of randomness, for reproducible episodes.
        """

    def reset(self) -> 'TimeStep':
        """Begin a new episode.

        Returns:
          The first ``TimeStep`` of the episode, carrying its observations.
        """

    def step(self, action) -> 'TimeStep':
        """Advance the environment by one step.

        Args:
          action: An action that satisfies ``action_spec()``.

        Returns:
          The resulting ``TimeStep``, with fresh observations and reward.
        """

    def action_spec(self) -> 'BoundedArraySpec':
        """Describe the actions this environment accepts.

        Returns:
          A spec for valid actions.
        """

    def observation_spec(self) -> dict:
        """Describe the observations this environment produces.

        Returns:
          A dict from observation name to its spec.
        """

66

67

class EpisodeInitializationError(Exception):
    """Raised when a new episode cannot be initialized."""

70

71

# Environment hooks: names of the callbacks an environment exposes so its
# behavior can be customized.
HOOK_NAMES: tuple

74

75

class ObservationPadding:
    """Helpers that pad observations so their shapes stay consistent."""

78

```

79

80

### Entity System

81

82

Base classes for all environment entities with observables and physics integration.

83

84

```python { .api }

85

class Entity:
    """Common base for everything that can be placed in an environment.

    An entity may stand for a physical object, an agent, or a more abstract
    component; entities are the building blocks environments are composed of.
    """

    def initialize_episode(self, physics: 'Physics', random_state: np.random.RandomState) -> None:
        """Prepare the entity at the start of an episode.

        Args:
          physics: The ``Physics`` instance used for this episode.
          random_state: Randomness source for any stochastic initialization.
        """

    def before_step(self, physics: 'Physics', action, random_state: np.random.RandomState) -> None:
        """Hook run just before the physics is stepped.

        Args:
          physics: The physics in its current state.
          action: The action about to be applied.
          random_state: Randomness source.
        """

    def after_step(self, physics: 'Physics', random_state: np.random.RandomState) -> None:
        """Hook run just after the physics has been stepped.

        Args:
          physics: The physics in its updated state.
          random_state: Randomness source.
        """

    @property
    def mjcf_model(self) -> 'RootElement':
        """The MJCF model describing this entity."""

    @property
    def observables(self) -> 'Observables':
        """The observable quantities this entity provides."""

128

129

class ModelWrapperEntity(Entity):
    """Adapter exposing an already-built MJCF model through the Entity API."""

    def __init__(self, mjcf_model: 'RootElement'):
        """Wrap ``mjcf_model``.

        Args:
          mjcf_model: The existing MJCF model to wrap.
        """

143

144

class FreePropObservableMixin:
    """Mixin that supplies observables for free-floating entities."""

147

148

class Robot(Entity):
    """Base for robot entities.

    Specializes ``Entity`` for robots, adding interfaces for actuation,
    sensing and control.
    """

    @property
    def actuators(self) -> list:
        """The robot's actuator elements."""

    @property
    def joints(self) -> list:
        """The robot's joint elements."""

163

```

164

165

### Arena System

166

167

Base classes for environment layouts and spatial organization.

168

169

```python { .api }

170

class Arena(Entity):
    """Base for arenas, the spatial setting of an environment.

    An arena lays out the environment's structure (surfaces, boundaries and
    the overall arrangement of space) and hosts the other entities.
    """

    @property
    def ground_geoms(self) -> list:
        """The geom elements that make up the ground."""

    def regenerate(self, random_state: np.random.RandomState) -> None:
        """Rebuild the arena layout.

        Args:
          random_state: Randomness source for stochastic layout generation.
        """

    def add_entity(self, entity: 'Entity', attachment_frame: 'Element' = None) -> None:
        """Place ``entity`` in the arena.

        Args:
          entity: The entity to add.
          attachment_frame: Optional element to attach the entity at.
        """

198

```

199

200

### Task System

201

202

Base classes for defining RL objectives and reward functions.

203

204

```python { .api }

205

class Task:
    """Base for RL tasks.

    A task specifies what the agent should achieve: its reward function,
    when an episode ends, and how each episode is set up.
    """

    def initialize_episode(self, physics: 'Physics', random_state: np.random.RandomState) -> None:
        """Set the task up for a fresh episode.

        Args:
          physics: The physics instance.
          random_state: Randomness source.
        """

    def before_step(self, physics: 'Physics', action, random_state: np.random.RandomState) -> None:
        """Hook run just before the physics is stepped.

        Args:
          physics: The physics in its current state.
          action: The action about to be applied.
          random_state: Randomness source.
        """

    def after_step(self, physics: 'Physics', random_state: np.random.RandomState) -> None:
        """Hook run just after the physics has been stepped.

        Args:
          physics: The physics in its updated state.
          random_state: Randomness source.
        """

    def get_reward(self, physics: 'Physics') -> float:
        """Compute the reward for the current state.

        Args:
          physics: The physics in its current state.

        Returns:
          The scalar reward.
        """

    def get_termination(self, physics: 'Physics') -> bool:
        """Decide whether the episode is over.

        Args:
          physics: The physics in its current state.

        Returns:
          ``True`` when the episode should end.
        """

    def get_discount(self, physics: 'Physics') -> float:
        """Give the discount factor for the current step.

        Args:
          physics: The physics in its current state.

        Returns:
          The discount factor, usually 1.0 or 0.0.
        """

    @property
    def observables(self) -> 'Observables':
        """The observable quantities this task provides."""

    @property
    def control_timestep(self) -> float:
        """Duration of one control step."""

281

282

class NullTask(Task):
    """A task with no objective, handy for unconstrained exploration."""

285

```

286

287

### Observable System

288

289

System for defining and managing observable quantities.

290

291

```python { .api }

292

class Observables:
    """A named collection of observable quantities.

    Tracks observables by name, generates their specifications
    automatically, and reads their current values out of the physics.
    """

    def add_observable(self, name: str, observable_callable: callable) -> None:
        """Register an observable under ``name``.

        Args:
          name: Name to register the observable under.
          observable_callable: Function producing the observable's value.
        """

    def get_observation(self, physics: 'Physics') -> dict:
        """Read every registered observable.

        Args:
          physics: The physics instance to read from.

        Returns:
          A dict from observable name to its current value.
        """

319

320

def observable(func: callable) -> callable:
    """Decorator that marks a method as an observable.

    Args:
      func: The method to mark as observable.

    Returns:
      The decorated method, carrying observable metadata.

    Example:
      >>> @observable
      ... def joint_positions(self, physics):
      ...     return physics.named.data.qpos[self.joints]
    """

336

337

def cached_property(func: callable) -> property:
    """Decorator for a property whose value is computed once and cached.

    Args:
      func: The method whose result should be cached.

    Returns:
      A property that stores its result after the first access.

    Example:
      >>> @cached_property
      ... def joint_names(self):
      ...     return [joint.name for joint in self.joints]
    """

353

```

354

355

### Initialization System

356

357

Base classes for entity initialization strategies.

358

359

```python { .api }

360

class Initializer:
    """Base for strategies that set entities up at episode start.

    An initializer decides how an entity is positioned and configured each
    time a new episode begins.
    """

    def __call__(self, physics: 'Physics', random_state: np.random.RandomState, entity: 'Entity') -> None:
        """Set ``entity`` up within ``physics``.

        Args:
          physics: The physics instance.
          random_state: Randomness source.
          entity: The entity to initialize.
        """

377

```

378

379

## Usage Examples

380

381

### Creating Custom Environments

382

383

```python

384

from dm_control import composer

385

from dm_control import mjcf

386

import numpy as np

387

388

# Create custom task
class ReachTask(composer.Task):
    """Reward bringing the site named 'hand_site' close to a 3-D target.

    Args:
      target_position: Initial target position, a length-3 sequence.
      randomize_target: If true (the default), a new target is drawn
        uniformly from [-1, 1) per axis at the start of every episode,
        replacing ``target_position``. Pass ``False`` to keep the given
        target for every episode.
    """

    def __init__(self, target_position, randomize_target=True):
        # Store a float copy so the caller's list is never aliased and the
        # subtraction in get_reward is always an array operation.
        self.target_position = np.array(target_position, dtype=float)
        self.randomize_target = randomize_target

    def initialize_episode(self, physics, random_state):
        # Optionally randomize the target position for each episode.
        if self.randomize_target:
            self.target_position = random_state.uniform(-1, 1, size=3)

    def get_reward(self, physics):
        # Simple distance-based reward: 1 at the target, decaying with distance.
        hand_pos = physics.named.data.site_xpos['hand_site']
        distance = np.linalg.norm(hand_pos - self.target_position)
        return np.exp(-distance)

    @composer.observable
    def target_position_obs(self, physics):
        # Return a copy so consumers cannot mutate the task's target.
        return self.target_position.copy()

406

407

# Create custom arena
class SimpleArena(composer.Arena):
    """An arena consisting of a single flat, grey ground plane."""

    def _build(self, name=None):
        # Let the base Arena build its MJCF root first; overriding _build
        # without this call leaves `mjcf_model` with nothing to add to.
        super()._build(name=name)
        self._ground = self.mjcf_model.worldbody.add(
            'geom', type='plane', size=[2, 2, 0.1], rgba=[0.5, 0.5, 0.5, 1])

    @property
    def ground_geoms(self):
        # Report the plane through the documented Arena API.
        return [self._ground]

412

413

# Assemble the environment from the arena and the task, with episodes
# capped at 10 seconds.
arena = SimpleArena()
task = ReachTask(target_position=[0.5, 0.5, 0.5])
env = composer.Environment(
    task=task,
    arena=arena,
    time_limit=10.0,
)

417

```

418

419

### Entity Composition

420

421

```python

422

# Load robot entity
robot_model = mjcf.from_path('/path/to/robot.xml')
robot = composer.ModelWrapperEntity(robot_model)

# ModelWrapperEntity has no `joints` property (only Robot defines one), so
# collect the joint elements from the wrapped MJCF model directly.
robot_joints = robot.mjcf_model.find_all('joint')


# Create observable for joint positions. `add_observable` takes a plain
# callable, so the method decorator `@composer.observable` is not used here.
def joint_positions(physics):
    # `physics.bind` maps the MJCF joint elements to their simulation state.
    return physics.bind(robot_joints).qpos


# Add observable to robot
robot.observables.add_observable('joint_pos', joint_positions)

433

434

# Create custom entity
class Ball(composer.Entity):
    """A red sphere of radius 0.05 that observes its own position."""

    def _build(self):
        # The entity builds and owns its MJCF root, returned by `mjcf_model`.
        self._mjcf_root = mjcf.RootElement(model='ball')
        self._body = self._mjcf_root.worldbody.add('body', name='ball')
        self._body.add('geom', type='sphere', size=[0.05], rgba=[1, 0, 0, 1])

    @property
    def mjcf_model(self):
        return self._mjcf_root

    @composer.observable
    def position(self, physics):
        # Bind the body element instead of looking up the bare name 'ball',
        # which gains a prefix once the entity is attached to an arena.
        return physics.bind(self._body).xpos


ball = Ball()

445

```

446

447

### Advanced Task Design

448

449

```python

450

class MultiObjectiveTask(composer.Task):
    """Task whose reward is a weighted sum of three objectives.

    Objectives, in order:
      0. reaching: mean over robots of exp(-||hand - target||),
      1. energy: -0.1 * sum of squared controls,
      2. smoothness: -0.01 * sum of squared joint velocities.

    Args:
      robots: Sequence of robots. Each robot ``r`` is assumed to expose
        ``r.name`` and to have a site named ``f'{r.name}_hand'``.
      targets: One target position per robot.
      weights: One weight per objective, in the order above.

    Raises:
      ValueError: If ``robots`` and ``targets`` differ in length, or if
        ``weights`` does not have exactly three entries.
    """

    def __init__(self, robots, targets, weights=(1.0, 0.5, 0.2)):
        if len(robots) != len(targets):
            raise ValueError(
                f'Expected one target per robot, got {len(robots)} robots '
                f'and {len(targets)} targets.')
        if len(weights) != 3:
            raise ValueError(f'Expected 3 objective weights, got {len(weights)}.')
        self.robots = robots
        self.targets = targets
        self.weights = [float(w) for w in weights]  # Objective weights

    def _objective_terms(self, physics):
        """Return the three unweighted objective values as an array."""
        # Primary objective: reach target. Averaging over robots yields one
        # term however many robots there are, so it always pairs with one weight.
        reach_terms = [
            np.exp(-np.linalg.norm(
                physics.named.data.site_xpos[f'{robot.name}_hand'] - target))
            for robot, target in zip(self.robots, self.targets)
        ]
        reach = np.mean(reach_terms) if reach_terms else 0.0

        # Secondary objective: energy efficiency.
        energy = -0.1 * np.sum(physics.data.ctrl ** 2)

        # Tertiary objective: smoothness. Read the whole velocity vector from
        # `physics.data`, as for `ctrl` above; `physics.named.data` is for
        # name-based lookups.
        smoothness = -0.01 * np.sum(physics.data.qvel ** 2)

        return np.array([reach, energy, smoothness])

    def get_reward(self, physics):
        # Exactly one weight per objective, so the shapes always match.
        return float(np.dot(self.weights, self._objective_terms(physics)))

    @composer.observable
    def objective_values(self, physics):
        # Return individual objective values for analysis.
        return self._objective_terms(physics)

479

```

480

481

### Observable Management

482

483

```python

484

class SensorEntity(composer.Entity):
    """Entity carrying an accelerometer and a gyro mounted on one site."""

    def _build(self):
        # The entity builds and owns its MJCF root, returned by `mjcf_model`.
        self._mjcf_root = mjcf.RootElement(model='sensor_entity')
        body = self._mjcf_root.worldbody.add('body', name='sensor_body')
        # The sensors below reference this site, so it must exist first.
        body.add('site', name='sensor_site')

        # Add sensors to model
        self._accel = self._mjcf_root.sensor.add(
            'accelerometer', name='accel', site='sensor_site')
        self._gyro = self._mjcf_root.sensor.add(
            'gyro', name='gyro', site='sensor_site')

    @property
    def mjcf_model(self):
        return self._mjcf_root

    @composer.observable
    def acceleration(self, physics):
        # Bind the sensor element rather than using the bare name 'accel',
        # which gains a prefix once the entity is attached to an arena.
        return physics.bind(self._accel).sensordata

    @composer.observable
    def angular_velocity(self, physics):
        return physics.bind(self._gyro).sensordata

    @composer.cached_property
    def sensor_site(self):
        # `find` takes the identifier positionally: find(namespace, identifier).
        return self.mjcf_model.find('site', 'sensor_site')

503

504

# Use in environment
sensor_entity = SensorEntity()
# Hand the entity to the task. Otherwise the entity, and its observables,
# never becomes part of the environment.
task = composer.NullTask(sensor_entity)
env = composer.Environment(task=task)

# Access observations
time_step = env.reset()
# NOTE: observation keys may be prefixed with the entity's model name
# (e.g. 'sensor_entity/acceleration'); check env.observation_spec().
accel_obs = time_step.observation['acceleration']
gyro_obs = time_step.observation['angular_velocity']

513

```