Continuous control environments and MuJoCo Python bindings for physics-based simulation and Reinforcement Learning
—
Framework for programmatically building complex reinforcement learning environments by combining entities, arenas, and tasks. Enables modular environment design with reusable components, flexible composition patterns, and comprehensive observation systems.
Core environment class for composer-based RL environments.
class Environment:
    """
    Composer environment for custom RL tasks.

    Provides the full RL environment interface (reset/step and
    action/observation specs) with modular composition of entities,
    arenas, and tasks.
    """
    def __init__(self, task: 'Task', arena: 'Arena' = None,
                 time_limit: float = float('inf'),
                 random_state: np.random.RandomState = None):
        """
        Initialize composer environment.

        Parameters:
        - task: Task instance defining objectives and rewards
        - arena: Optional arena for environment layout (default: task arena)
        - time_limit: Episode time limit in seconds (default: infinite)
        - random_state: Random state for reproducibility
        """

    def reset(self) -> 'TimeStep':
        """
        Reset environment and return initial timestep.

        Returns:
        Initial TimeStep with observations
        """

    def step(self, action) -> 'TimeStep':
        """
        Apply action and advance environment.

        Parameters:
        - action: Action conforming to action_spec()

        Returns:
        TimeStep with new observations and rewards
        """

    def action_spec(self) -> 'BoundedArraySpec':
        """
        Get action specification.

        Returns:
        Specification describing valid actions
        """

    def observation_spec(self) -> dict:
        """
        Get observation specification.

        Returns:
        Dict mapping observation names to specs
        """
class EpisodeInitializationError(Exception):
    """Error raised during episode initialization (e.g. a failed entity placement)."""
    pass
# Environment hooks: names of lifecycle callbacks (e.g. initialize_episode,
# before_step, after_step) that the framework invokes on entities/tasks.
# Exact contents are defined by the library — not visible in this stub.
HOOK_NAMES: tuple
"""Names of available environment hooks for customization."""
class ObservationPadding:
    """Utilities for padding observations to consistent shapes (e.g. across variable-size entities)."""
pass

Base classes for all environment entities with observables and physics integration.
class Entity:
    """
    Base class for all environment entities.

    Entities represent physical objects, agents, or abstract components
    that can be composed into environments. Subclasses participate in the
    episode lifecycle through the hooks below.
    """
    def initialize_episode(self, physics: 'Physics', random_state: np.random.RandomState) -> None:
        """
        Initialize entity for new episode.

        Parameters:
        - physics: Physics instance for the episode
        - random_state: Random state for stochastic initialization
        """

    def before_step(self, physics: 'Physics', action, random_state: np.random.RandomState) -> None:
        """
        Called before each physics step.

        Parameters:
        - physics: Current physics state
        - action: Action being applied
        - random_state: Random state
        """

    def after_step(self, physics: 'Physics', random_state: np.random.RandomState) -> None:
        """
        Called after each physics step.

        Parameters:
        - physics: Updated physics state
        - random_state: Random state
        """

    @property
    def mjcf_model(self) -> 'RootElement':
        """MJCF model for this entity."""

    @property
    def observables(self) -> 'Observables':
        """Observable quantities for this entity."""
class ModelWrapperEntity(Entity):
    """
    Entity that wraps an existing MJCF model.

    Provides the Entity interface (lifecycle hooks, observables) for
    pre-existing MJCF models loaded from elsewhere.
    """
    def __init__(self, mjcf_model: 'RootElement'):
        """
        Initialize with an MJCF model.

        Parameters:
        - mjcf_model: MJCF model to wrap
        """
class FreePropObservableMixin:
    """Mixin for entities with free-floating observables (e.g. unattached props)."""
    pass
class Robot(Entity):
    """
    Base class for robotic entities.

    Specialized entity for robots with actuation, sensing,
    and control interfaces.
    """
    @property
    def actuators(self) -> list:
        """List of actuator MJCF elements."""

    @property
    def joints(self) -> list:
"""List of joint elements."""Base classes for environment layouts and spatial organization.
class Arena(Entity):
    """
    Base class for environment arenas.

    Arenas define the spatial layout and structure of environments,
    providing surfaces, boundaries, and spatial organization.
    """
    @property
    def ground_geoms(self) -> list:
        """Ground geometry MJCF elements."""

    def regenerate(self, random_state: np.random.RandomState) -> None:
        """
        Regenerate arena layout.

        Parameters:
        - random_state: Random state for stochastic generation
        """

    def add_entity(self, entity: 'Entity', attachment_frame: 'Element' = None) -> None:
        """
        Add entity to arena.

        Parameters:
        - entity: Entity to add
        - attachment_frame: Optional attachment point
"""Base classes for defining RL objectives and reward functions.
class Task:
    """
    Base class for RL tasks.

    Tasks define objectives, reward functions, termination conditions,
    and episode initialization for RL environments.
    """
    def initialize_episode(self, physics: 'Physics', random_state: np.random.RandomState) -> None:
        """
        Initialize task for new episode.

        Parameters:
        - physics: Physics instance
        - random_state: Random state
        """

    def before_step(self, physics: 'Physics', action, random_state: np.random.RandomState) -> None:
        """
        Called before each physics step.

        Parameters:
        - physics: Current physics state
        - action: Action being applied
        - random_state: Random state
        """

    def after_step(self, physics: 'Physics', random_state: np.random.RandomState) -> None:
        """
        Called after each physics step.

        Parameters:
        - physics: Updated physics state
        - random_state: Random state
        """

    def get_reward(self, physics: 'Physics') -> float:
        """
        Calculate reward for the current state.

        Parameters:
        - physics: Current physics state

        Returns:
        Scalar reward value
        """

    def get_termination(self, physics: 'Physics') -> bool:
        """
        Check if episode should terminate.

        Parameters:
        - physics: Current physics state

        Returns:
        True if episode should end
        """

    def get_discount(self, physics: 'Physics') -> float:
        """
        Get discount factor for the current step.

        Parameters:
        - physics: Current physics state

        Returns:
        Discount factor (typically 1.0 or 0.0)
        """

    @property
    def observables(self) -> 'Observables':
        """Observable quantities for this task."""

    @property
    def control_timestep(self) -> float:
        """Control timestep duration."""
class NullTask(Task):
    """Task with no objectives — useful for free exploration."""
pass

System for defining and managing observable quantities.
class Observables:
    """
    Collection of observable quantities with specifications.

    Manages named observables with automatic specification generation
    and value extraction from physics.
    """
    def add_observable(self, name: str, observable_callable: callable) -> None:
        """
        Add a named observable.

        Parameters:
        - name: Observable name
        - observable_callable: Function returning the observable's value
        """

    def get_observation(self, physics: 'Physics') -> dict:
        """
        Extract all observable values.

        Parameters:
        - physics: Physics instance

        Returns:
        Dict mapping observable names to values
        """
# NOTE(review): the decorator appears self-decorated in this stub — likely a
# documentation-generation artifact rather than real usage.
@observable
def observable(func: callable) -> callable:
    """
    Decorator for marking methods as observable.

    Parameters:
    - func: Method to mark as observable

    Returns:
    Decorated method with observable metadata

    Example:
    >>> @observable
    ... def joint_positions(self, physics):
    ...     return physics.named.data.qpos[self.joints]
    """
# NOTE(review): self-decoration here is likewise a documentation artifact.
@cached_property
def cached_property(func: callable) -> property:
    """
    Decorator for cached property computation.

    Parameters:
    - func: Method to cache

    Returns:
    Property that caches the result after first access

    Example:
    >>> @cached_property
    ... def joint_names(self):
    ...     return [joint.name for joint in self.joints]
"""Base classes for entity initialization strategies.
class Initializer:
    """
    Base class for initialization strategies.

    Initializers define how entities should be positioned and configured
    at the start of each episode.
    """
    def __call__(self, physics: 'Physics', random_state: np.random.RandomState, entity: 'Entity') -> None:
        """
        Initialize the given entity in physics.

        Parameters:
        - physics: Physics instance
        - random_state: Random state
        - entity: Entity to initialize
"""from dm_control import composer
from dm_control import mjcf
import numpy as np
# Create custom task
class ReachTask(composer.Task):
    """Example task: dense reward for moving the hand toward a target."""

    def __init__(self, target_position):
        """
        Parameters:
        - target_position: Initial 3-D target position; re-randomized
          at the start of every episode by initialize_episode.
        """
        # Let the framework base class set up its internal state first.
        super().__init__()
        self.target_position = target_position

    def initialize_episode(self, physics, random_state):
        # Randomize target position for each new episode.
        self.target_position = random_state.uniform(-1, 1, size=3)

    def get_reward(self, physics):
        # Simple distance-based reward in (0, 1]: exp(-|hand - target|).
        hand_pos = physics.named.data.site_xpos['hand_site']
        distance = np.linalg.norm(hand_pos - self.target_position)
        return np.exp(-distance)

    @composer.observable
    def target_position_obs(self, physics):
        # Expose the current target so the agent can observe it.
        return self.target_position
# Create custom arena
class SimpleArena(composer.Arena):
    """Minimal flat-plane arena for the reach example."""

    def _build(self):
        # Let the base Arena construct its MJCF root model first; without
        # this, self.mjcf_model is not yet available inside _build.
        super()._build()
        self.mjcf_model.worldbody.add(
            'geom',
            type='plane', size=[2, 2, 0.1], rgba=[0.5, 0.5, 0.5, 1])
# Create environment: the task supplies rewards/objectives, the arena the layout.
arena = SimpleArena()
task = ReachTask(target_position=[0.5, 0.5, 0.5])
env = composer.Environment(task=task, arena=arena, time_limit=10.0)

# Load robot entity
robot_model = mjcf.from_path('/path/to/robot.xml')  # parse an existing MJCF file
robot = composer.ModelWrapperEntity(robot_model)  # wrap it with the Entity interface

# Create observable for joint positions.
# NOTE(review): @composer.observable is applied here to a free function, and the
# same function is also registered via add_observable below — confirm that both
# usages are supported by the library.
@composer.observable
def joint_positions(physics):
    return physics.named.data.qpos[robot.joints]

# Add observable to robot under the name 'joint_pos'.
robot.observables.add_observable('joint_pos', joint_positions)
# Create custom entity
class Ball(composer.Entity):
    """Example entity: a red sphere with an observable world position."""

    def _build(self):
        # NOTE(review): assumes self.mjcf_model already exists when _build runs —
        # confirm; plain Entity subclasses typically must create their own MJCF
        # root element here.
        self.mjcf_model.worldbody.add('body', name='ball').add(
            'geom', type='sphere', size=[0.05], rgba=[1, 0, 0, 1])

    @composer.observable
    def position(self, physics):
        # Cartesian position of the 'ball' body in the world frame.
        return physics.named.data.xpos['ball']
ball = Ball()

class MultiObjectiveTask(composer.Task):
def __init__(self, robots, targets):
    """
    Parameters:
    - robots: Robot entities whose hands should reach the matching targets
    - targets: Target positions, one per robot
    """
    self.robots = robots
    self.targets = targets
    # Weights for [reach, energy, smoothness] objectives used in get_reward.
    self.weights = [1.0, 0.5, 0.2]  # Objective weights
def get_reward(self, physics):
    """
    Weighted sum of three objectives: reach, energy efficiency, smoothness.

    Parameters:
    - physics: Current physics state

    Returns:
    Scalar reward = weights . [reach, -0.1*control_cost, -0.01*velocity_cost]
    """
    # Primary objective: reach target. Aggregate across robots so the
    # objective vector always has exactly 3 entries, matching self.weights.
    # (The original appended one entry per robot, which makes np.dot fail
    # for any number of robots other than 1.)
    reach_terms = []
    for robot, target in zip(self.robots, self.targets):
        hand_pos = physics.named.data.site_xpos[f'{robot.name}_hand']
        distance = np.linalg.norm(hand_pos - target)
        reach_terms.append(np.exp(-distance))
    rewards = [np.mean(reach_terms)]
    # Secondary objective: energy efficiency (penalize control effort).
    control_cost = np.sum(physics.data.ctrl ** 2)
    rewards.append(-0.1 * control_cost)
    # Tertiary objective: smoothness (penalize joint velocities).
    velocity_cost = np.sum(physics.named.data.qvel ** 2)
    rewards.append(-0.01 * velocity_cost)
    return np.dot(rewards, self.weights)
@composer.observable
def objective_values(self, physics):
    # Return individual objective values for analysis.
    # NOTE(review): despite the comment, this wraps the single combined reward
    # from get_reward in a 1-element array — confirm whether per-objective
    # values were intended here.
return np.array([self.get_reward(physics)])

class SensorEntity(composer.Entity):
def _build(self):
    # Add accelerometer and gyro sensors to the model, both attached to the
    # same site.
    # NOTE(review): assumes self.mjcf_model already exists when _build runs —
    # confirm; Entity subclasses typically construct their MJCF root here.
    self.mjcf_model.sensor.add('accelerometer',
                               name='accel', site='sensor_site')
    self.mjcf_model.sensor.add('gyro',
                               name='gyro', site='sensor_site')
@composer.observable
def acceleration(self, physics):
    """Latest reading of the 'accel' accelerometer sensor."""
    sensor_readings = physics.named.data.sensordata
    return sensor_readings['accel']
@composer.observable
def angular_velocity(self, physics):
    """Latest reading of the 'gyro' angular-velocity sensor."""
    sensor_readings = physics.named.data.sensordata
    return sensor_readings['gyro']
@composer.cached_property
def sensor_site(self):
    """MJCF site element that both sensors are attached to (looked up once)."""
    model = self.mjcf_model
    return model.find('site', name='sensor_site')
# Use in environment.
sensor_entity = SensorEntity()
# NOTE(review): NullTask may require a root entity argument — confirm signature;
# also note sensor_entity is never added to the environment in this snippet.
task = composer.NullTask()
env = composer.Environment(task=task)
# Access observations; keys are assumed to match the observable method names —
# confirm against the library's naming scheme.
time_step = env.reset()
accel_obs = time_step.observation['acceleration']
gyro_obs = time_step.observation['angular_velocity']

Install with Tessl CLI
npx tessl i tessl/pypi-dm-control