CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-dm-control

Continuous control environments and MuJoCo Python bindings for physics-based simulation and Reinforcement Learning

Pending
Overview
Eval results
Files

docs/composer.md

Environment Composition

Framework for programmatically building complex reinforcement learning environments by combining entities, arenas, and tasks. Enables modular environment design with reusable components, flexible composition patterns, and comprehensive observation systems.

Capabilities

Environment Building

Core environment class for composer-based RL environments.

class Environment:
    """
    Composer environment for custom RL tasks.

    Exposes the standard RL environment interface (reset, step, action and
    observation specs) over a modular composition of entities, arenas and
    tasks.
    """

    def __init__(self, task: 'Task', arena: 'Arena | None' = None,
                 time_limit: float = float('inf'),
                 random_state: 'np.random.RandomState | None' = None):
        """
        Initialize composer environment.

        Parameters:
        - task: Task instance defining objectives and rewards
        - arena: Optional arena for the environment layout; when None, the
          task's own arena is used
        - time_limit: Episode time limit in seconds (default: infinite)
        - random_state: Optional random state for reproducibility
        """

    def reset(self) -> 'TimeStep':
        """
        Reset the environment and return the initial timestep.

        Returns:
        Initial TimeStep with observations
        """

    def step(self, action) -> 'TimeStep':
        """
        Apply an action and advance the environment.

        Parameters:
        - action: Action conforming to action_spec()

        Returns:
        TimeStep with new observations and rewards
        """

    def action_spec(self) -> 'BoundedArraySpec':
        """
        Get the action specification.

        Returns:
        Specification describing valid actions
        """

    def observation_spec(self) -> dict:
        """
        Get the observation specification.

        Returns:
        Dict mapping observation names to specs
        """

class EpisodeInitializationError(Exception):
    """Signals a failure while setting up a new episode."""

# Environment hooks
HOOK_NAMES: tuple
    """Names of available environment hooks for customization."""

class ObservationPadding:
    """Helpers for padding observations so they share a consistent shape."""

Entity System

Base classes for all environment entities with observables and physics integration.

class Entity:
    """
    Common base for every environment entity.

    An entity is any composable piece of an environment -- a physical
    object, an agent, or a more abstract component -- that can be combined
    with others to build an environment.
    """

    def initialize_episode(self, physics: 'Physics', random_state: np.random.RandomState) -> None:
        """
        Prepare this entity at the start of a new episode.

        Parameters:
        - physics: Physics instance the episode runs in
        - random_state: Source of randomness for stochastic setup
        """

    def before_step(self, physics: 'Physics', action, random_state: np.random.RandomState) -> None:
        """
        Hook invoked ahead of each physics step.

        Parameters:
        - physics: Physics state prior to the step
        - action: The action about to be applied
        - random_state: Source of randomness
        """

    def after_step(self, physics: 'Physics', random_state: np.random.RandomState) -> None:
        """
        Hook invoked once a physics step has completed.

        Parameters:
        - physics: Physics state following the step
        - random_state: Source of randomness
        """

    @property
    def mjcf_model(self) -> 'RootElement':
        """The MJCF model describing this entity."""

    @property
    def observables(self) -> 'Observables':
        """The observable quantities this entity exposes."""

class ModelWrapperEntity(Entity):
    """
    Adapter that presents an already-built MJCF model as an Entity.

    Use it when the model exists up front (for example, loaded from XML)
    and only needs the Entity interface around it.
    """

    def __init__(self, mjcf_model: 'RootElement'):
        """
        Wrap the given MJCF model.

        Parameters:
        - mjcf_model: The MJCF model to expose as an entity
        """

class FreePropObservableMixin:
    """Mixin that adds observables to free-floating (prop) entities."""

class Robot(Entity):
    """
    Common base for robot entities.

    A specialization of Entity for robots, adding the actuation, sensing
    and control interfaces robots need.
    """

    @property
    def actuators(self) -> list:
        """The robot's actuator elements."""

    @property
    def joints(self) -> list:
        """The robot's joint elements."""

Arena System

Base classes for environment layouts and spatial organization.

class Arena(Entity):
    """
    Base class for environment arenas.

    Arenas define the spatial layout and structure of environments,
    providing surfaces, boundaries, and spatial organization.
    """

    @property
    def ground_geoms(self) -> list:
        """Ground geometry elements."""

    def regenerate(self, random_state: np.random.RandomState) -> None:
        """
        Regenerate the arena layout.

        Parameters:
        - random_state: Random state for stochastic generation
        """

    def add_entity(self, entity: 'Entity', attachment_frame: 'Element | None' = None) -> None:
        """
        Add an entity to the arena.

        Parameters:
        - entity: Entity to add
        - attachment_frame: Optional attachment point; None when omitted
          (the annotation says so explicitly rather than relying on an
          implicit Optional)
        """

Task System

Base classes for defining RL objectives and reward functions.

class Task:
    """
    Common base for reinforcement-learning tasks.

    A task specifies what the agent should achieve: its reward function,
    when an episode terminates, how discounting applies, and how each
    episode is initialized.
    """

    def initialize_episode(self, physics: 'Physics', random_state: np.random.RandomState) -> None:
        """
        Set the task up for a fresh episode.

        Parameters:
        - physics: Physics instance for the new episode
        - random_state: Source of randomness
        """

    def before_step(self, physics: 'Physics', action, random_state: np.random.RandomState) -> None:
        """
        Hook run just before the physics is stepped.

        Parameters:
        - physics: Physics state prior to the step
        - action: The action about to be applied
        - random_state: Source of randomness
        """

    def after_step(self, physics: 'Physics', random_state: np.random.RandomState) -> None:
        """
        Hook run right after the physics has been stepped.

        Parameters:
        - physics: Physics state following the step
        - random_state: Source of randomness
        """

    def get_reward(self, physics: 'Physics') -> float:
        """
        Compute the reward for the present physics state.

        Parameters:
        - physics: Physics state to evaluate

        Returns:
        The scalar reward
        """

    def get_termination(self, physics: 'Physics') -> bool:
        """
        Decide whether the episode is over.

        Parameters:
        - physics: Physics state to evaluate

        Returns:
        True when the episode should end
        """

    def get_discount(self, physics: 'Physics') -> float:
        """
        Report the discount for the present step.

        Parameters:
        - physics: Physics state to evaluate

        Returns:
        The discount factor (usually 1.0 or 0.0)
        """

    @property
    def observables(self) -> 'Observables':
        """The observable quantities this task exposes."""

    @property
    def control_timestep(self) -> float:
        """Length of one control step."""

class NullTask(Task):
    """A task without objectives, handy for unconstrained exploration."""

Observable System

System for defining and managing observable quantities.

class Observables:
    """
    Named collection of observable quantities and their specifications.

    Keeps observables keyed by name, derives their specifications, and
    reads their current values out of the physics.
    """

    def add_observable(self, name: str, observable_callable: callable) -> None:
        """
        Register an observable under a name.

        Parameters:
        - name: Name to register the observable under
        - observable_callable: Callable producing the observable's value
        """

    def get_observation(self, physics: 'Physics') -> dict:
        """
        Read the current value of every observable.

        Parameters:
        - physics: Physics instance to read from

        Returns:
        Dict from observable name to its value
        """

def observable(func: callable) -> callable:
    """
    Decorator for marking methods as observable.

    Note: the function must not decorate itself (``@observable`` on its own
    definition), because the name does not exist yet at that point and
    importing the module would raise NameError.

    Parameters:
    - func: Method to mark as observable

    Returns:
    The decorated method (in this summary stub, returned unchanged so that
    the decorated name remains callable)

    Example:
    >>> @observable
    ... def joint_positions(self, physics):
    ...     return physics.named.data.qpos[self.joints]
    """
    return func

def cached_property(func: callable) -> property:
    """
    Decorator for cached property computation.

    The wrapped method runs at most once per instance; later accesses return
    the stored result. (This function must not decorate itself: the name is
    unbound at that point, so importing the module would raise NameError.)

    Parameters:
    - func: Method to cache

    Returns:
    Property that caches the result after first access

    Example:
    >>> @cached_property
    ... def joint_names(self):
    ...     return [joint.name for joint in self.joints]
    """
    import functools

    # Sentinel so that a legitimately cached None is not recomputed.
    missing = object()
    cache_attr = f'_cached_{func.__name__}'

    @functools.wraps(func)
    def getter(self):
        value = getattr(self, cache_attr, missing)
        if value is missing:
            value = func(self)
            setattr(self, cache_attr, value)
        return value

    return property(getter)

Initialization System

Base classes for entity initialization strategies.

class Initializer:
    """
    Common base for entity initialization strategies.

    An initializer decides how an entity is placed and configured when
    each episode begins.
    """

    def __call__(self, physics: 'Physics', random_state: np.random.RandomState, entity: 'Entity') -> None:
        """
        Apply this strategy to one entity.

        Parameters:
        - physics: Physics instance to modify
        - random_state: Source of randomness
        - entity: The entity being initialized
        """

Usage Examples

Creating Custom Environments

from dm_control import composer
from dm_control import mjcf
import numpy as np

# Create custom task
class ReachTask(composer.Task):
    """Reward the agent for bringing 'hand_site' close to a 3-D target."""

    def __init__(self, target_position, randomize=True):
        """
        Parameters:
        - target_position: (x, y, z) target. It is used as-is when
          randomize=False, and only until the first reset otherwise.
        - randomize: If True (default), draw a fresh target uniformly from
          [-1, 1]^3 at the start of every episode.
        """
        self._initial_target = np.asarray(target_position, dtype=float)
        self.target_position = self._initial_target.copy()
        self.randomize = randomize

    def initialize_episode(self, physics, random_state):
        if self.randomize:
            # Randomize target position
            self.target_position = random_state.uniform(-1, 1, size=3)
        else:
            # Without this branch the constructor's target would be
            # discarded on every reset.
            self.target_position = self._initial_target.copy()

    def get_reward(self, physics):
        # Simple distance-based reward in (0, 1]
        hand_pos = physics.named.data.site_xpos['hand_site']
        distance = np.linalg.norm(hand_pos - self.target_position)
        return np.exp(-distance)

    @composer.observable
    def target_position_obs(self, physics):
        return self.target_position

# Create custom arena
class SimpleArena(composer.Arena):
    def _build(self, name=None):
        # composer.Arena._build creates the arena's MJCF root; an override
        # must call it first, otherwise self.mjcf_model has nothing to add to.
        super()._build(name=name)
        self.mjcf_model.worldbody.add(
            'geom', type='plane', size=[2, 2, 0.1], rgba=[0.5, 0.5, 0.5, 1])

# Create environment
arena = SimpleArena()
task = ReachTask(target_position=[0.5, 0.5, 0.5])
env = composer.Environment(task=task, arena=arena, time_limit=10.0)

Entity Composition

# Load robot entity
robot_model = mjcf.from_path('/path/to/robot.xml')
robot = composer.ModelWrapperEntity(robot_model)

# ModelWrapperEntity has no `joints` property (only Robot does), so collect
# the joint elements directly from the wrapped MJCF model.
robot_joints = robot_model.find_all('joint')

# Create observable for joint positions
@composer.observable
def joint_positions(physics):
    # physics.bind maps MJCF elements to their simulation values, so this
    # does not depend on how joint names are prefixed after attachment.
    return physics.bind(robot_joints).qpos

# Add observable to robot
robot.observables.add_observable('joint_pos', joint_positions)

# Create custom entity
class Ball(composer.Entity):
    """A small red sphere exposing its world position as an observable."""

    def _build(self):
        # The Entity base does not build a model for subclasses: create the
        # MJCF root here and expose it through the mjcf_model property.
        self._mjcf_root = mjcf.RootElement(model='ball')
        self._body = self._mjcf_root.worldbody.add('body', name='ball')
        self._body.add('geom', type='sphere', size=[0.05], rgba=[1, 0, 0, 1])

    @property
    def mjcf_model(self):
        return self._mjcf_root

    @composer.observable
    def position(self, physics):
        # Bind the body element instead of looking up 'ball' by name: element
        # names gain a model prefix once the entity is attached to an arena.
        return physics.bind(self._body).xpos

ball = Ball()

Advanced Task Design

class MultiObjectiveTask(composer.Task):
    """Weighted sum of reaching, energy-efficiency and smoothness objectives."""

    def __init__(self, robots, targets):
        self.robots = robots
        self.targets = targets
        # One weight per objective: [reach, energy efficiency, smoothness].
        self.weights = [1.0, 0.5, 0.2]

    def _objective_values(self, physics):
        """Return the three objective values as [reach, energy, smoothness]."""
        # Primary objective: reach target. Average over robots so that there is
        # always exactly one reach term per weight, whatever len(robots) is.
        # (Appending one term per robot broke np.dot for anything but one robot.)
        reach = [
            np.exp(-np.linalg.norm(
                physics.named.data.site_xpos[f'{robot.name}_hand'] - target))
            for robot, target in zip(self.robots, self.targets)
        ]
        reach_reward = np.mean(reach) if reach else 0.0

        # Secondary objective: energy efficiency
        control_cost = np.sum(physics.data.ctrl ** 2)

        # Tertiary objective: smoothness (raw qvel array, matching ctrl above)
        velocity_cost = np.sum(physics.data.qvel ** 2)

        return np.array([reach_reward, -0.1 * control_cost, -0.01 * velocity_cost])

    def get_reward(self, physics):
        return float(np.dot(self._objective_values(physics), self.weights))

    @composer.observable
    def objective_values(self, physics):
        # Individual (unweighted) objective values for analysis, rather than
        # the combined reward.
        return self._objective_values(physics)

Observable Management

class SensorEntity(composer.Entity):
    """A small body carrying an accelerometer and a gyro on 'sensor_site'."""

    def _build(self):
        # Entities must create their own MJCF model (exposed via mjcf_model).
        self._mjcf_root = mjcf.RootElement(model='sensor_entity')
        # The sensors are attached to 'sensor_site', so the site must exist.
        body = self._mjcf_root.worldbody.add('body', name='sensor_body')
        body.add('geom', type='box', size=[0.02, 0.02, 0.02])
        body.add('site', name='sensor_site')
        # Add sensors to model
        self._accel = self._mjcf_root.sensor.add(
            'accelerometer', name='accel', site='sensor_site')
        self._gyro = self._mjcf_root.sensor.add(
            'gyro', name='gyro', site='sensor_site')

    @property
    def mjcf_model(self):
        return self._mjcf_root

    @composer.observable
    def acceleration(self, physics):
        # Bind the sensor element so the lookup survives name prefixing.
        return physics.bind(self._accel).sensordata

    @composer.observable
    def angular_velocity(self, physics):
        return physics.bind(self._gyro).sensordata

    @composer.cached_property
    def sensor_site(self):
        # mjcf `find` takes (namespace, identifier) positionally.
        return self.mjcf_model.find('site', 'sensor_site')

# Use in environment: the entity must be part of the scene, otherwise its
# observables cannot appear in the timestep.
sensor_entity = SensorEntity()
arena = composer.Arena()
arena.add_entity(sensor_entity)
task = composer.NullTask()
env = composer.Environment(task=task, arena=arena)

# Access observations
time_step = env.reset()
accel_obs = time_step.observation['acceleration']
gyro_obs = time_step.observation['angular_velocity']

Install with Tessl CLI

npx tessl i tessl/pypi-dm-control

docs

composer.md

index.md

mjcf.md

physics.md

suite.md

viewer.md

tile.json