Ray is a unified framework for scaling AI and Python applications.
Ray RLlib provides reinforcement learning algorithms and environment abstractions with support for distributed training and multiple deep learning frameworks (TensorFlow and PyTorch). It includes implementations of state-of-the-art RL algorithms and tools for developing custom environments.
Base reinforcement learning functionality and algorithm management.
class Policy:
"""Base class for RL policies."""
def compute_actions(self, obs_batch, state_batches=None,
prev_action_batch=None, prev_reward_batch=None,
info_batch=None, episodes=None, **kwargs):
"""
Compute actions for a batch of observations.
Args:
obs_batch: Batch of observations
state_batches (list, optional): List of RNN state batches
prev_action_batch: Previous actions
prev_reward_batch: Previous rewards
info_batch: Info dictionaries
episodes: Episode objects
Returns:
tuple: (actions, state_outs, extra_info)
"""
def compute_actions_from_input_dict(self, input_dict, explore=None,
timestep=None, **kwargs):
"""
Compute actions from input dictionary.
Args:
input_dict (dict): Input dictionary with observations
explore (bool, optional): Whether to explore
timestep (int, optional): Current timestep
Returns:
tuple: (actions, state_outs, extra_info)
"""
def learn_on_batch(self, samples):
"""
Learn from a batch of samples.
Args:
samples: Batch of training samples
Returns:
dict: Training statistics
"""
def get_weights(self):
"""
Get policy weights.
Returns:
dict: Policy weights
"""
def set_weights(self, weights):
"""
Set policy weights.
Args:
weights (dict): Policy weights to set
"""
def export_model(self, export_dir, onnx=None):
"""
Export policy model.
Args:
export_dir (str): Directory to export to
onnx (int, optional): ONNX opset version
"""
class Algorithm:
"""Base class for RL algorithms."""
def __init__(self, config=None, env=None, logger_creator=None):
"""
Initialize RL algorithm.
Args:
config (dict, optional): Algorithm configuration
env: Environment or environment string
logger_creator: Logger creator function
"""
def train(self):
"""
Perform one training iteration.
Returns:
dict: Training results
"""
def evaluate(self, duration_fn=None, evaluation_fn=None):
"""
Evaluate current policy.
Args:
duration_fn: Function to determine evaluation duration
evaluation_fn: Custom evaluation function
Returns:
dict: Evaluation results
"""
def compute_single_action(self, observation, state=None,
prev_action=None, prev_reward=None,
info=None, policy_id="default_policy",
full_fetch=False, explore=None):
"""
Compute single action from observation.
Args:
observation: Single observation
state: RNN state
prev_action: Previous action
prev_reward: Previous reward
info: Info dictionary
policy_id (str): Policy ID to use
full_fetch (bool): Whether to return full info
explore (bool, optional): Whether to explore
Returns:
Action or tuple with additional info
"""
def save(self, checkpoint_dir=None):
"""
Save algorithm checkpoint.
Args:
checkpoint_dir (str, optional): Directory to save to
Returns:
str: Checkpoint path
"""
def restore(self, checkpoint_path):
"""
Restore algorithm from checkpoint.
Args:
checkpoint_path (str): Path to checkpoint
"""
def stop(self):
"""Stop algorithm and cleanup resources."""
def get_policy(self, policy_id="default_policy"):
"""
Get policy by ID.
Args:
policy_id (str): Policy ID
Returns:
Policy: Policy object
"""
def add_policy(self, policy_id, policy_cls, observation_space=None,
action_space=None, config=None, policy_state=None):
"""
Add new policy to algorithm.
Args:
policy_id (str): Policy ID
policy_cls: Policy class
observation_space: Observation space
action_space: Action space
config (dict, optional): Policy configuration
policy_state: Policy state
"""
def remove_policy(self, policy_id):
"""
Remove policy from algorithm.
Args:
policy_id (str): Policy ID to remove
"""Work with RL environments and wrappers.
class MultiAgentEnv:
"""Base class for multi-agent environments."""
def reset(self):
"""
Reset environment.
Returns:
dict: Initial observations for each agent
"""
def step(self, action_dict):
"""
Step environment with actions.
Args:
action_dict (dict): Actions for each agent
Returns:
tuple: (obs_dict, reward_dict, done_dict, info_dict)
"""
def render(self, mode="human"):
"""Render environment."""
def close(self):
"""Close environment."""
def make_multi_agent(env_name_or_creator):
"""
Create multi-agent version of environment.
Args:
env_name_or_creator: Environment name or creator function
Returns:
MultiAgentEnv: Multi-agent environment
"""
class BaseEnv:
"""Base class for vectorized environments."""
def poll(self):
"""
Poll for completed episodes.
Returns:
tuple: (obs_dict, reward_dict, done_dict, info_dict, off_policy_actions_dict)
"""
def send_actions(self, action_dict):
"""
Send actions to environments.
Args:
action_dict (dict): Actions for each environment
"""
def try_reset(self, env_id):
"""
Try to reset specific environment.
Args:
env_id: Environment ID
Returns:
dict or None: Observation if reset successful
"""Configure algorithms and define spaces.
class AlgorithmConfig:
"""Configuration for RL algorithms."""
def __init__(self, algo_class=None):
"""Initialize algorithm configuration."""
def environment(self, env=None, *, env_config=None, observation_space=None,
action_space=None, **kwargs):
"""
Configure environment settings.
Args:
env: Environment or environment string
env_config (dict, optional): Environment configuration
observation_space: Observation space
action_space: Action space
Returns:
AlgorithmConfig: Self for chaining
"""
def framework(self, framework=None, *, eager_tracing=None, **kwargs):
"""
Configure ML framework.
Args:
framework (str, optional): Framework ("tf", "tf2", "torch")
eager_tracing (bool, optional): Enable eager tracing
Returns:
AlgorithmConfig: Self for chaining
"""
def resources(self, *, num_gpus=None, num_cpus_per_worker=None,
num_gpus_per_worker=None, **kwargs):
"""
Configure resource usage.
Args:
num_gpus (float, optional): Number of GPUs
num_cpus_per_worker (float, optional): CPUs per worker
num_gpus_per_worker (float, optional): GPUs per worker
Returns:
AlgorithmConfig: Self for chaining
"""
def rollouts(self, *, num_rollout_workers=None, num_envs_per_worker=None,
rollout_fragment_length=None, **kwargs):
"""
Configure rollout collection.
Args:
num_rollout_workers (int, optional): Number of rollout workers
num_envs_per_worker (int, optional): Environments per worker
rollout_fragment_length (int, optional): Rollout fragment length
Returns:
AlgorithmConfig: Self for chaining
"""
def training(self, *, lr=None, train_batch_size=None, **kwargs):
"""
Configure training settings.
Args:
lr (float, optional): Learning rate
train_batch_size (int, optional): Training batch size
Returns:
AlgorithmConfig: Self for chaining
"""
def evaluation(self, *, evaluation_interval=None, evaluation_duration=None,
**kwargs):
"""
Configure evaluation settings.
Args:
evaluation_interval (int, optional): Evaluation interval
evaluation_duration (int, optional): Evaluation duration
Returns:
AlgorithmConfig: Self for chaining
"""
def build(self, env=None, logger_creator=None):
"""
Build algorithm from configuration.
Args:
env: Environment override
logger_creator: Logger creator override
Returns:
Algorithm: Built algorithm
"""Implementations of specific RL algorithms.
class PPOConfig(AlgorithmConfig):
"""Configuration for Proximal Policy Optimization."""
class PPO(Algorithm):
"""Proximal Policy Optimization algorithm."""
class SACConfig(AlgorithmConfig):
"""Configuration for Soft Actor-Critic."""
class SAC(Algorithm):
"""Soft Actor-Critic algorithm."""
class DQNConfig(AlgorithmConfig):
"""Configuration for Deep Q-Network."""
class DQN(Algorithm):
"""Deep Q-Network algorithm."""
class A3CConfig(AlgorithmConfig):
"""Configuration for Asynchronous Advantage Actor-Critic."""
class A3C(Algorithm):
"""Asynchronous Advantage Actor-Critic algorithm."""
class IMPALAConfig(AlgorithmConfig):
"""Configuration for IMPALA."""
class IMPALA(Algorithm):
"""IMPALA algorithm."""Utility functions for RL development.
def register_env(name, env_creator):
"""
Register environment with Ray RLlib.
Args:
name (str): Environment name
env_creator: Function that creates environment
"""
class ModelCatalog:
"""Catalog for registering custom models and preprocessors."""
@staticmethod
def register_custom_model(model_name, model_class):
"""
Register custom model.
Args:
model_name (str): Model name
model_class: Model class
"""
@staticmethod
def register_custom_preprocessor(preprocessor_name, preprocessor_class):
"""
Register custom preprocessor.
Args:
preprocessor_name (str): Preprocessor name
preprocessor_class: Preprocessor class
"""
@staticmethod
def register_custom_action_dist(action_dist_name, action_dist_class):
"""
Register custom action distribution.
Args:
action_dist_name (str): Action distribution name
action_dist_class: Action distribution class
"""
def rollout(agent, env_name, num_steps=None, num_episodes=1,
no_render=False, video_dir=None):
"""
Rollout trained agent in environment.
Args:
agent: Trained agent/algorithm
env_name (str): Environment name
num_steps (int, optional): Number of steps
num_episodes (int): Number of episodes
no_render (bool): Whether to disable rendering
video_dir (str, optional): Directory to save videos
Returns:
list: Episode rewards
"""import ray
from ray.rllib.algorithms.ppo import PPOConfig
# Initialize Ray
ray.init()
# Configure PPO algorithm
config = (PPOConfig()
.environment(env="CartPole-v1")
.rollouts(num_rollout_workers=2)
.training(lr=0.0001, train_batch_size=4000)
.evaluation(evaluation_interval=10))
# Build algorithm
algo = config.build()
# Training loop
for i in range(100):
result = algo.train()
print(f"Iteration {i}: reward={result['episode_reward_mean']}")
# Save checkpoint every 10 iterations
if i % 10 == 0:
checkpoint_path = algo.save()
print(f"Checkpoint saved at {checkpoint_path}")
# Clean up
algo.stop()
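# Sketch (not part of the original example): a saved checkpoint can also be
# loaded back into a freshly built algorithm via restore(); this reuses the
# last checkpoint_path produced by the loop above.
restored_algo = config.build()
restored_algo.restore(checkpoint_path)
restored_algo.stop()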
ray.shutdown()
import ray
from ray.rllib.env.env_context import EnvContext
from ray.rllib.algorithms.dqn import DQNConfig
import gym
class CustomEnv(gym.Env):
def __init__(self, config: EnvContext):
self.action_space = gym.spaces.Discrete(2)
self.observation_space = gym.spaces.Box(-1, 1, shape=(4,))
self.config = config
def reset(self):
return self.observation_space.sample()
def step(self, action):
obs = self.observation_space.sample()
reward = 1.0 if action == 1 else 0.0
done = False
info = {}
return obs, reward, done, info
# Register environment
from ray.tune.registry import register_env
register_env("custom_env", lambda config: CustomEnv(config))
ray.init()
# Train on custom environment
config = (DQNConfig()
.environment(env="custom_env", env_config={"param": "value"})
.training(lr=0.001))
algo = config.build()
for i in range(50):
result = algo.train()
print(f"Episode reward: {result['episode_reward_mean']}")
algo.stop()
import ray
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.algorithms.ppo import PPOConfig
import gym
from ray.tune.registry import register_env
class MultiAgentCustomEnv(MultiAgentEnv):
def __init__(self, config):
self.agents = ["agent_1", "agent_2"]
self.action_space = gym.spaces.Discrete(2)
self.observation_space = gym.spaces.Box(-1, 1, shape=(4,))
def reset(self):
return {agent: self.observation_space.sample()
for agent in self.agents}
def step(self, action_dict):
obs = {agent: self.observation_space.sample()
for agent in self.agents}
rewards = {agent: 1.0 for agent in self.agents}
dones = {"__all__": False}
infos = {agent: {} for agent in self.agents}
return obs, rewards, dones, infos
register_env("multi_agent_env", lambda _: MultiAgentCustomEnv({}))
ray.init()
config = (PPOConfig()
.environment(env="multi_agent_env")
.multi_agent(
policies={
"policy_1": (None, None, None, {}),
"policy_2": (None, None, None, {}),
},
policy_mapping_fn=lambda agent_id, episode, **kwargs:
"policy_1" if agent_id == "agent_1" else "policy_2"
))
algo = config.build()
for i in range(30):
result = algo.train()
print(f"Iteration {i}: {result['episode_reward_mean']}")
algo.stop()
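While a multi-agent Algorithm like the one above is still running (before algo.stop() is called), policies can also be added and removed at runtime via the add_policy and remove_policy methods listed earlier. A hedged sketch; PPOTorchPolicy is used as an illustrative policy class, and the spaces match the custom environment above:
import gym
from ray.rllib.algorithms.ppo import PPOTorchPolicy

algo.add_policy(
    policy_id="policy_3",
    policy_cls=PPOTorchPolicy,
    observation_space=gym.spaces.Box(-1, 1, shape=(4,)),
    action_space=gym.spaces.Discrete(2),
)
# ... train some more, then drop the policy again.
algo.remove_policy("policy_3")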
import ray
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.algorithms.ppo import PPOConfig
import torch
import torch.nn as nn
class CustomModel(TorchModelV2, nn.Module):
def __init__(self, obs_space, action_space, num_outputs,
model_config, name):
TorchModelV2.__init__(self, obs_space, action_space,
num_outputs, model_config, name)
nn.Module.__init__(self)
self.shared_layers = nn.Sequential(
nn.Linear(obs_space.shape[0], 128),
nn.ReLU(),
nn.Linear(128, 128),
nn.ReLU(),
)
self.policy_head = nn.Linear(128, num_outputs)
self.value_head = nn.Linear(128, 1)
self._value = None
def forward(self, input_dict, state, seq_lens):
features = self.shared_layers(input_dict["obs"])
logits = self.policy_head(features)
self._value = self.value_head(features).squeeze(1)
return logits, state
def value_function(self):
return self._value
# Register custom model
ModelCatalog.register_custom_model("custom_model", CustomModel)
ray.init()
config = (PPOConfig()
.environment(env="CartPole-v1")
.training(model={"custom_model": "custom_model"}))
algo = config.build()
for i in range(20):
result = algo.train()
print(f"Reward: {result['episode_reward_mean']}")
algo.stop()
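Besides the manual rollout shown in the next example, an Algorithm whose config includes evaluation settings can be evaluated directly with evaluate(). A hedged sketch, assuming `algo` is a still-running Algorithm built with something like .evaluation(evaluation_interval=1, evaluation_duration=10):
eval_results = algo.evaluate()
# Recent RLlib versions nest the metrics under an "evaluation" key;
# inspect the dict and adjust the lookup for your version.
print(eval_results)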
import ray
from ray.rllib.algorithms.ppo import PPO
import gym
ray.init()
# Restore trained algorithm
algo = PPO.from_checkpoint("/path/to/checkpoint")
# Create environment for evaluation
env = gym.make("CartPole-v1")
# Run episodes with trained agent
for episode in range(5):
obs = env.reset()
done = False
total_reward = 0
while not done:
action = algo.compute_single_action(obs)
obs, reward, done, info = env.step(action)
total_reward += reward
env.render()
print(f"Episode {episode}: Total reward = {total_reward}")
env.close()
algo.stop()
Install with Tessl CLI
npx tessl i tessl/pypi-ray