tessl/pypi-ray

Ray is a unified framework for scaling AI and Python applications.

Reinforcement Learning

Ray RLlib is a library of reinforcement learning algorithms with built-in support for distributed training on TensorFlow and PyTorch. It ships implementations of state-of-the-art algorithms alongside tools for building custom environments, models, and multi-agent setups.

Capabilities

Core RL Framework

Base reinforcement learning functionality and algorithm management.

class Policy:
    """Base class for RL policies."""
    
    def compute_actions(self, obs_batch, state_batches=None, 
                       prev_action_batch=None, prev_reward_batch=None,
                       info_batch=None, episodes=None, **kwargs):
        """
        Compute actions for a batch of observations.
        
        Args:
            obs_batch: Batch of observations
            state_batches (list, optional): List of RNN state batches
            prev_action_batch: Previous actions
            prev_reward_batch: Previous rewards
            info_batch: Info dictionaries
            episodes: Episode objects
        
        Returns:
            tuple: (actions, state_outs, extra_info)
        """
    
    def compute_actions_from_input_dict(self, input_dict, explore=None,
                                       timestep=None, **kwargs):
        """
        Compute actions from input dictionary.
        
        Args:
            input_dict (dict): Input dictionary with observations
            explore (bool, optional): Whether to explore
            timestep (int, optional): Current timestep
        
        Returns:
            tuple: (actions, state_outs, extra_info)
        """
    
    def learn_on_batch(self, samples):
        """
        Learn from a batch of samples.
        
        Args:
            samples: Batch of training samples
        
        Returns:
            dict: Training statistics
        """
    
    def get_weights(self):
        """
        Get policy weights.
        
        Returns:
            dict: Policy weights
        """
    
    def set_weights(self, weights):
        """
        Set policy weights.
        
        Args:
            weights (dict): Policy weights to set
        """
    
    def export_model(self, export_dir, onnx=None):
        """
        Export policy model.
        
        Args:
            export_dir (str): Directory to export to
            onnx (int, optional): ONNX opset version
        """

class Algorithm:
    """Base class for RL algorithms."""
    
    def __init__(self, config=None, env=None, logger_creator=None):
        """
        Initialize RL algorithm.
        
        Args:
            config (dict, optional): Algorithm configuration
            env: Environment or environment string
            logger_creator: Logger creator function
        """
    
    def train(self):
        """
        Perform one training iteration.
        
        Returns:
            dict: Training results
        """
    
    def evaluate(self, duration_fn=None, evaluation_fn=None):
        """
        Evaluate current policy.
        
        Args:
            duration_fn: Function to determine evaluation duration
            evaluation_fn: Custom evaluation function
        
        Returns:
            dict: Evaluation results
        """
    
    def compute_single_action(self, observation, state=None, 
                             prev_action=None, prev_reward=None,
                             info=None, policy_id="default_policy", 
                             full_fetch=False, explore=None):
        """
        Compute single action from observation.
        
        Args:
            observation: Single observation
            state: RNN state
            prev_action: Previous action
            prev_reward: Previous reward
            info: Info dictionary
            policy_id (str): Policy ID to use
            full_fetch (bool): Whether to return full info
            explore (bool, optional): Whether to explore
        
        Returns:
            Action or tuple with additional info
        """
    
    def save(self, checkpoint_dir=None):
        """
        Save algorithm checkpoint.
        
        Args:
            checkpoint_dir (str, optional): Directory to save to
        
        Returns:
            str: Checkpoint path
        """
    
    def restore(self, checkpoint_path):
        """
        Restore algorithm from checkpoint.
        
        Args:
            checkpoint_path (str): Path to checkpoint
        """
    
    def stop(self):
        """Stop algorithm and cleanup resources."""
    
    def get_policy(self, policy_id="default_policy"):
        """
        Get policy by ID.
        
        Args:
            policy_id (str): Policy ID
        
        Returns:
            Policy: Policy object
        """
    
    def add_policy(self, policy_id, policy_cls, observation_space=None,
                  action_space=None, config=None, policy_state=None):
        """
        Add new policy to algorithm.
        
        Args:
            policy_id (str): Policy ID
            policy_cls: Policy class
            observation_space: Observation space
            action_space: Action space
            config (dict, optional): Policy configuration
            policy_state: Policy state
        """
    
    def remove_policy(self, policy_id):
        """
        Remove policy from algorithm.
        
        Args:
            policy_id (str): Policy ID to remove
        """

Environment Integration

Work with RL environments and wrappers.

class MultiAgentEnv:
    """Base class for multi-agent environments."""
    
    def reset(self):
        """
        Reset environment.
        
        Returns:
            dict: Initial observations for each agent
        """
    
    def step(self, action_dict):
        """
        Step environment with actions.
        
        Args:
            action_dict (dict): Actions for each agent
        
        Returns:
            tuple: (obs_dict, reward_dict, done_dict, info_dict)
        """
    
    def render(self, mode="human"):
        """Render environment."""
    
    def close(self):
        """Close environment."""

def make_multi_agent(env_name_or_creator):
    """
    Create multi-agent version of environment.
    
    Args:
        env_name_or_creator: Environment name or creator function
    
    Returns:
        MultiAgentEnv: Multi-agent environment
    """

class BaseEnv:
    """Base class for vectorized environments."""
    
    def poll(self):
        """
        Poll for completed episodes.
        
        Returns:
            tuple: (obs_dict, reward_dict, done_dict, info_dict, off_policy_actions_dict)
        """
    
    def send_actions(self, action_dict):
        """
        Send actions to environments.
        
        Args:
            action_dict (dict): Actions for each environment
        """
    
    def try_reset(self, env_id):
        """
        Try to reset specific environment.
        
        Args:
            env_id: Environment ID
        
        Returns:
            dict or None: Observation if reset successful
        """

Configuration and Spaces

Configure algorithms and define spaces.

class AlgorithmConfig:
    """Configuration for RL algorithms."""
    
    def __init__(self, algo_class=None):
        """Initialize algorithm configuration."""
    
    def environment(self, env=None, *, env_config=None, observation_space=None,
                   action_space=None, **kwargs):
        """
        Configure environment settings.
        
        Args:
            env: Environment or environment string
            env_config (dict, optional): Environment configuration
            observation_space: Observation space
            action_space: Action space
        
        Returns:
            AlgorithmConfig: Self for chaining
        """
    
    def framework(self, framework=None, *, eager_tracing=None, **kwargs):
        """
        Configure ML framework.
        
        Args:
            framework (str, optional): Framework ("tf", "tf2", "torch")
            eager_tracing (bool, optional): Enable eager tracing
        
        Returns:
            AlgorithmConfig: Self for chaining
        """
    
    def resources(self, *, num_gpus=None, num_cpus_per_worker=None,
                 num_gpus_per_worker=None, **kwargs):
        """
        Configure resource usage.
        
        Args:
            num_gpus (float, optional): Number of GPUs
            num_cpus_per_worker (float, optional): CPUs per worker
            num_gpus_per_worker (float, optional): GPUs per worker
        
        Returns:
            AlgorithmConfig: Self for chaining
        """
    
    def rollouts(self, *, num_rollout_workers=None, num_envs_per_worker=None,
                rollout_fragment_length=None, **kwargs):
        """
        Configure rollout collection.
        
        Args:
            num_rollout_workers (int, optional): Number of rollout workers
            num_envs_per_worker (int, optional): Environments per worker
            rollout_fragment_length (int, optional): Rollout fragment length
        
        Returns:
            AlgorithmConfig: Self for chaining
        """
    
    def training(self, *, lr=None, train_batch_size=None, **kwargs):
        """
        Configure training settings.
        
        Args:
            lr (float, optional): Learning rate
            train_batch_size (int, optional): Training batch size
        
        Returns:
            AlgorithmConfig: Self for chaining
        """
    
    def evaluation(self, *, evaluation_interval=None, evaluation_duration=None,
                  **kwargs):
        """
        Configure evaluation settings.
        
        Args:
            evaluation_interval (int, optional): Evaluation interval
            evaluation_duration (int, optional): Evaluation duration
        
        Returns:
            AlgorithmConfig: Self for chaining
        """
    
    def build(self, env=None, logger_creator=None):
        """
        Build algorithm from configuration.
        
        Args:
            env: Environment override
            logger_creator: Logger creator override
        
        Returns:
            Algorithm: Built algorithm
        """

Specific RL Algorithms

Implementations of specific RL algorithms.

class PPOConfig(AlgorithmConfig):
    """Configuration for Proximal Policy Optimization."""

class PPO(Algorithm):
    """Proximal Policy Optimization algorithm."""

class SACConfig(AlgorithmConfig):
    """Configuration for Soft Actor-Critic."""

class SAC(Algorithm):
    """Soft Actor-Critic algorithm."""

class DQNConfig(AlgorithmConfig):
    """Configuration for Deep Q-Network."""

class DQN(Algorithm):
    """Deep Q-Network algorithm."""

class A3CConfig(AlgorithmConfig):
    """Configuration for Asynchronous Advantage Actor-Critic."""

class A3C(Algorithm):
    """Asynchronous Advantage Actor-Critic algorithm."""

class IMPALAConfig(AlgorithmConfig):
    """Configuration for IMPALA."""

class IMPALA(Algorithm):
    """IMPALA algorithm."""

Utilities and Helpers

Utility functions for RL development.

def register_env(name, env_creator):
    """
    Register environment with Ray RLlib.
    
    Args:
        name (str): Environment name
        env_creator: Function that creates environment
    """

class ModelCatalog:
    """Catalog for registering custom models and preprocessors."""
    
    @staticmethod
    def register_custom_model(model_name, model_class):
        """
        Register custom model.
        
        Args:
            model_name (str): Model name
            model_class: Model class
        """
    
    @staticmethod
    def register_custom_preprocessor(preprocessor_name, preprocessor_class):
        """
        Register custom preprocessor.
        
        Args:
            preprocessor_name (str): Preprocessor name
            preprocessor_class: Preprocessor class
        """
    
    @staticmethod
    def register_custom_action_dist(action_dist_name, action_dist_class):
        """
        Register custom action distribution.
        
        Args:
            action_dist_name (str): Action distribution name
            action_dist_class: Action distribution class
        """

def rollout(agent, env_name, num_steps=None, num_episodes=1, 
           no_render=False, video_dir=None):
    """
    Rollout trained agent in environment.
    
    Args:
        agent: Trained agent/algorithm
        env_name (str): Environment name
        num_steps (int, optional): Number of steps
        num_episodes (int): Number of episodes
        no_render (bool): Whether to disable rendering
        video_dir (str, optional): Directory to save videos
    
    Returns:
        list: Episode rewards
    """

Usage Examples

Basic RL Training

import ray
from ray.rllib.algorithms.ppo import PPOConfig

# Initialize Ray
ray.init()

# Configure PPO algorithm
config = (PPOConfig()
          .environment(env="CartPole-v1")
          .rollouts(num_rollout_workers=2)
          .training(lr=0.0001, train_batch_size=4000)
          .evaluation(evaluation_interval=10))

# Build algorithm
algo = config.build()

# Training loop
for i in range(100):
    result = algo.train()
    print(f"Iteration {i}: reward={result['episode_reward_mean']}")
    
    # Save checkpoint every 10 iterations
    if i % 10 == 0:
        checkpoint_path = algo.save()
        print(f"Checkpoint saved at {checkpoint_path}")

# Clean up
algo.stop()
ray.shutdown()

Custom Environment

import ray
from ray.rllib.env.env_context import EnvContext
from ray.rllib.algorithms.dqn import DQNConfig
import gym

class CustomEnv(gym.Env):
    def __init__(self, config: EnvContext):
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Box(-1, 1, shape=(4,))
        self.config = config
    
    def reset(self):
        return self.observation_space.sample()
    
    def step(self, action):
        obs = self.observation_space.sample()
        reward = 1.0 if action == 1 else 0.0
        done = False
        info = {}
        return obs, reward, done, info

# Register environment
from ray.tune.registry import register_env
register_env("custom_env", lambda config: CustomEnv(config))

ray.init()

# Train on custom environment
config = (DQNConfig()
          .environment(env="custom_env", env_config={"param": "value"})
          .training(lr=0.001))

algo = config.build()

for i in range(50):
    result = algo.train()
    print(f"Episode reward: {result['episode_reward_mean']}")

algo.stop()

Multi-Agent RL

import ray
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.algorithms.ppo import PPOConfig
import gym

class MultiAgentCustomEnv(MultiAgentEnv):
    def __init__(self, config):
        super().__init__()
        self.agents = ["agent_1", "agent_2"]
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Box(-1, 1, shape=(4,))
    
    def reset(self):
        return {agent: self.observation_space.sample() 
                for agent in self.agents}
    
    def step(self, action_dict):
        obs = {agent: self.observation_space.sample() 
               for agent in self.agents}
        rewards = {agent: 1.0 for agent in self.agents}
        dones = {"__all__": False}
        infos = {agent: {} for agent in self.agents}
        return obs, rewards, dones, infos

from ray.tune.registry import register_env

register_env("multi_agent_env", lambda _: MultiAgentCustomEnv({}))

ray.init()

config = (PPOConfig()
          .environment(env="multi_agent_env")
          .multi_agent(
              policies={
                  "policy_1": (None, None, None, {}),
                  "policy_2": (None, None, None, {}),
              },
              policy_mapping_fn=lambda agent_id, episode, **kwargs: 
                  "policy_1" if agent_id == "agent_1" else "policy_2"
          ))

algo = config.build()

for i in range(30):
    result = algo.train()
    print(f"Iteration {i}: {result['episode_reward_mean']}")

algo.stop()

Custom Model

import ray
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.algorithms.ppo import PPOConfig
import torch
import torch.nn as nn

class CustomModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, 
                 model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, 
                             num_outputs, model_config, name)
        nn.Module.__init__(self)
        
        self.shared_layers = nn.Sequential(
            nn.Linear(obs_space.shape[0], 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
        )
        
        self.policy_head = nn.Linear(128, num_outputs)
        self.value_head = nn.Linear(128, 1)
        self._value = None
    
    def forward(self, input_dict, state, seq_lens):
        features = self.shared_layers(input_dict["obs"])
        logits = self.policy_head(features)
        self._value = self.value_head(features).squeeze(1)
        return logits, state
    
    def value_function(self):
        return self._value

# Register custom model
ModelCatalog.register_custom_model("custom_model", CustomModel)

ray.init()

config = (PPOConfig()
          .environment(env="CartPole-v1")
          .training(model={"custom_model": "custom_model"}))

algo = config.build()

for i in range(20):
    result = algo.train()
    print(f"Reward: {result['episode_reward_mean']}")

algo.stop()

Loading and Using Trained Agent

import ray
from ray.rllib.algorithms.ppo import PPO
import gym

ray.init()

# Restore trained algorithm
algo = PPO.from_checkpoint("/path/to/checkpoint")

# Create environment for evaluation
env = gym.make("CartPole-v1")

# Run episodes with trained agent
for episode in range(5):
    obs = env.reset()
    done = False
    total_reward = 0
    
    while not done:
        action = algo.compute_single_action(obs)
        obs, reward, done, info = env.step(action)
        total_reward += reward
        env.render()
    
    print(f"Episode {episode}: Total reward = {total_reward}")

env.close()
algo.stop()

Install with Tessl CLI

npx tessl i tessl/pypi-ray
