# Reinforcement Learning

Ray RLlib provides reinforcement learning algorithms and environments with support for distributed training and multiple deep-learning frameworks (TensorFlow and PyTorch). It includes implementations of state-of-the-art RL algorithms and tools for developing custom environments.

## Capabilities

### Core RL Framework

Base reinforcement learning functionality and algorithm management.

```python { .api }
class Policy:
    """Base class that every RL policy derives from."""

    def compute_actions(
        self,
        obs_batch,
        state_batches=None,
        prev_action_batch=None,
        prev_reward_batch=None,
        info_batch=None,
        episodes=None,
        **kwargs,
    ):
        """Compute actions for a whole batch of observations.

        Args:
            obs_batch: The batch of observations.
            state_batches (list, optional): RNN state batches, as a list.
            prev_action_batch: Actions taken at the previous step.
            prev_reward_batch: Rewards received at the previous step.
            info_batch: Info dictionaries.
            episodes: Episode objects.
            **kwargs: Additional keyword arguments.

        Returns:
            tuple: ``(actions, state_outs, extra_info)``.
        """

    def compute_actions_from_input_dict(
        self, input_dict, explore=None, timestep=None, **kwargs
    ):
        """Compute actions from an input dictionary.

        Args:
            input_dict (dict): Input dictionary holding the observations.
            explore (bool, optional): Whether to apply exploration.
            timestep (int, optional): The current timestep.
            **kwargs: Additional keyword arguments.

        Returns:
            tuple: ``(actions, state_outs, extra_info)``.
        """

    def learn_on_batch(self, samples):
        """Run a learning update on a batch of samples.

        Args:
            samples: The batch of training samples.

        Returns:
            dict: Training statistics.
        """

    def get_weights(self):
        """Return the policy's weights.

        Returns:
            dict: The policy weights.
        """

    def set_weights(self, weights):
        """Overwrite the policy's weights.

        Args:
            weights (dict): The policy weights to load.
        """

    def export_model(self, export_dir, onnx=None):
        """Export the policy's model.

        Args:
            export_dir (str): Target directory for the export.
            onnx (int, optional): ONNX opset version to export with.
        """

class Algorithm:
    """Base class that every RL algorithm derives from."""

    def __init__(self, config=None, env=None, logger_creator=None):
        """Set up the RL algorithm.

        Args:
            config (dict, optional): The algorithm configuration.
            env: An environment instance or a registered environment name.
            logger_creator: Callable that creates the logger.
        """

    def train(self):
        """Run a single training iteration.

        Returns:
            dict: Results of the training iteration.
        """

    def evaluate(self, duration_fn=None, evaluation_fn=None):
        """Evaluate the current policy.

        Args:
            duration_fn: Callable that decides how long to evaluate.
            evaluation_fn: Custom evaluation function.

        Returns:
            dict: Evaluation results.
        """

    def compute_single_action(
        self,
        observation,
        state=None,
        prev_action=None,
        prev_reward=None,
        info=None,
        policy_id="default_policy",
        full_fetch=False,
        explore=None,
    ):
        """Compute an action for a single observation.

        Args:
            observation: The observation to act on.
            state: RNN state.
            prev_action: Action taken at the previous step.
            prev_reward: Reward received at the previous step.
            info: Info dictionary.
            policy_id (str): ID of the policy to query.
            full_fetch (bool): Whether to return the full extra info.
            explore (bool, optional): Whether to apply exploration.

        Returns:
            The action alone, or a tuple carrying additional info.
        """

    def save(self, checkpoint_dir=None):
        """Write an algorithm checkpoint.

        Args:
            checkpoint_dir (str, optional): Directory to write into.

        Returns:
            str: Path of the saved checkpoint.
        """

    def restore(self, checkpoint_path):
        """Load the algorithm state from a checkpoint.

        Args:
            checkpoint_path (str): Path of the checkpoint to load.
        """

    def stop(self):
        """Shut the algorithm down and release its resources."""

    def get_policy(self, policy_id="default_policy"):
        """Look up a policy by its ID.

        Args:
            policy_id (str): ID of the policy.

        Returns:
            Policy: The matching policy object.
        """

    def add_policy(
        self,
        policy_id,
        policy_cls,
        observation_space=None,
        action_space=None,
        config=None,
        policy_state=None,
    ):
        """Add a new policy to the algorithm.

        Args:
            policy_id (str): ID for the new policy.
            policy_cls: The policy class to instantiate.
            observation_space: The policy's observation space.
            action_space: The policy's action space.
            config (dict, optional): The policy configuration.
            policy_state: Initial policy state.
        """

    def remove_policy(self, policy_id):
        """Remove a policy from the algorithm.

        Args:
            policy_id (str): ID of the policy to remove.
        """
```

### Environment Integration

Interfaces for multi-agent and vectorized environments, plus a helper that converts a single-agent environment into a multi-agent one.

```python { .api }
class MultiAgentEnv:
    """Base class for environments that host several agents at once."""

    def reset(self):
        """Reset the environment to the start of an episode.

        Returns:
            dict: The initial observation of each agent, keyed by agent ID.
        """

    def step(self, action_dict):
        """Advance the environment by one step.

        Args:
            action_dict (dict): The action of each agent, keyed by agent ID.

        Returns:
            tuple: ``(obs_dict, reward_dict, done_dict, info_dict)``. The
            done_dict may use the special ``"__all__"`` key to signal that
            the episode has ended for every agent.
        """

    def render(self, mode="human"):
        """Render the environment."""

    def close(self):
        """Close the environment."""

def make_multi_agent(env_name_or_creator):
    """Create a multi-agent version of a single-agent environment.

    The result is a MultiAgentEnv *class*, not an instance. Each instance of
    that class stacks independent copies of the single-agent environment,
    one per agent. Instantiate it with an env config such as
    ``{"num_agents": 2}``, or register it with ``register_env``.

    Args:
        env_name_or_creator: Registered environment name, or a creator
            function that returns a single-agent environment.

    Returns:
        type: A MultiAgentEnv subclass.
    """

class BaseEnv:
    """Base class for vectorized environments.

    Per-step data is exchanged as two-level dicts, keyed first by
    sub-environment ID and then by agent ID.
    """

    def poll(self):
        """Return new data from the agents that are ready to act.

        Returns:
            tuple: ``(obs_dict, reward_dict, done_dict, info_dict,
            off_policy_actions_dict)``, each keyed by env ID, then agent ID.
        """

    def send_actions(self, action_dict):
        """Send actions to the sub-environments.

        Args:
            action_dict (dict): Actions keyed by env ID, then agent ID.
        """

    def try_reset(self, env_id):
        """Try to reset a specific sub-environment.

        Args:
            env_id: ID of the sub-environment to reset.

        Returns:
            dict or None: The reset observation dict, or None if the
            environment does not support a synchronous reset.
        """
```

### Configuration and Spaces

Configure algorithms, including the environment's observation and action spaces.

```python { .api }
class AlgorithmConfig:
    """Builder-style configuration object for RL algorithms."""

    def __init__(self, algo_class=None):
        """Create a new algorithm configuration."""

    def environment(
        self,
        env=None,
        *,
        env_config=None,
        observation_space=None,
        action_space=None,
        **kwargs,
    ):
        """Set environment-related options.

        Args:
            env: An environment instance or a registered environment name.
            env_config (dict, optional): Environment configuration.
            observation_space: The observation space.
            action_space: The action space.

        Returns:
            AlgorithmConfig: This config object, to allow chaining.
        """

    def framework(self, framework=None, *, eager_tracing=None, **kwargs):
        """Select the deep-learning framework.

        Args:
            framework (str, optional): One of ``"tf"``, ``"tf2"``, ``"torch"``.
            eager_tracing (bool, optional): Whether to enable eager tracing.

        Returns:
            AlgorithmConfig: This config object, to allow chaining.
        """

    def resources(
        self,
        *,
        num_gpus=None,
        num_cpus_per_worker=None,
        num_gpus_per_worker=None,
        **kwargs,
    ):
        """Set compute-resource usage.

        Args:
            num_gpus (float, optional): Number of GPUs.
            num_cpus_per_worker (float, optional): CPUs for each worker.
            num_gpus_per_worker (float, optional): GPUs for each worker.

        Returns:
            AlgorithmConfig: This config object, to allow chaining.
        """

    def rollouts(
        self,
        *,
        num_rollout_workers=None,
        num_envs_per_worker=None,
        rollout_fragment_length=None,
        **kwargs,
    ):
        """Set how rollouts are collected.

        Args:
            num_rollout_workers (int, optional): Number of rollout workers.
            num_envs_per_worker (int, optional): Environments per worker.
            rollout_fragment_length (int, optional): Length of each rollout
                fragment.

        Returns:
            AlgorithmConfig: This config object, to allow chaining.
        """

    def training(self, *, lr=None, train_batch_size=None, **kwargs):
        """Set training hyperparameters.

        Args:
            lr (float, optional): Learning rate.
            train_batch_size (int, optional): Training batch size.

        Returns:
            AlgorithmConfig: This config object, to allow chaining.
        """

    def evaluation(
        self, *, evaluation_interval=None, evaluation_duration=None, **kwargs
    ):
        """Set evaluation options.

        Args:
            evaluation_interval (int, optional): Evaluation interval.
            evaluation_duration (int, optional): Evaluation duration.

        Returns:
            AlgorithmConfig: This config object, to allow chaining.
        """

    def build(self, env=None, logger_creator=None):
        """Construct an algorithm from this configuration.

        Args:
            env: Overrides the configured environment.
            logger_creator: Overrides the logger creator.

        Returns:
            Algorithm: The constructed algorithm.
        """
```

### Specific RL Algorithms

Configuration and algorithm classes for specific RL algorithms (PPO, SAC, DQN, A3C, and IMPALA).

```python { .api }
class PPOConfig(AlgorithmConfig):
    """Configuration for the Proximal Policy Optimization (PPO) algorithm."""

class PPO(Algorithm):
    """The Proximal Policy Optimization (PPO) algorithm."""

class SACConfig(AlgorithmConfig):
    """Configuration for the Soft Actor-Critic (SAC) algorithm."""

class SAC(Algorithm):
    """The Soft Actor-Critic (SAC) algorithm."""

class DQNConfig(AlgorithmConfig):
    """Configuration for the Deep Q-Network (DQN) algorithm."""

class DQN(Algorithm):
    """The Deep Q-Network (DQN) algorithm."""

class A3CConfig(AlgorithmConfig):
    """Configuration for the Asynchronous Advantage Actor-Critic (A3C) algorithm."""

class A3C(Algorithm):
    """The Asynchronous Advantage Actor-Critic (A3C) algorithm."""

class IMPALAConfig(AlgorithmConfig):
    """Configuration for the IMPALA (Importance Weighted Actor-Learner Architecture) algorithm."""

class IMPALA(Algorithm):
    """The IMPALA (Importance Weighted Actor-Learner Architecture) algorithm."""
```

### Utilities and Helpers

Utility functions and classes for RL development: environment registration, custom model registration, and rolling out trained agents.

```python { .api }
def register_env(name, env_creator):
    """Register an environment creator under a name that RLlib can look up.

    In Ray this function is imported from ``ray.tune.registry``. Once it is
    registered, ``name`` can be passed as ``env`` to
    ``AlgorithmConfig.environment()``.

    Args:
        name (str): Environment name.
        env_creator: Callable that receives the environment config
            (``env_config``) and returns an environment instance, e.g.
            ``lambda config: MyEnv(config)``.
    """

class ModelCatalog:
    """Registry for custom models, preprocessors, and action distributions."""

    @staticmethod
    def register_custom_model(model_name, model_class):
        """Register a custom model class under a name.

        Args:
            model_name (str): Name to register the model under.
            model_class: The model class.
        """

    @staticmethod
    def register_custom_preprocessor(preprocessor_name, preprocessor_class):
        """Register a custom preprocessor class under a name.

        Args:
            preprocessor_name (str): Name to register the preprocessor under.
            preprocessor_class: The preprocessor class.
        """

    @staticmethod
    def register_custom_action_dist(action_dist_name, action_dist_class):
        """Register a custom action distribution class under a name.

        Args:
            action_dist_name (str): Name to register the distribution under.
            action_dist_class: The action distribution class.
        """

def rollout(
    agent,
    env_name,
    num_steps=None,
    num_episodes=1,
    no_render=False,
    video_dir=None,
):
    """Run a trained agent in an environment.

    Args:
        agent: The trained agent/algorithm.
        env_name (str): Name of the environment.
        num_steps (int, optional): Number of steps to run.
        num_episodes (int): Number of episodes to run.
        no_render (bool): If True, rendering is disabled.
        video_dir (str, optional): Directory in which to save videos.

    Returns:
        list: The episode rewards.
    """
```

## Usage Examples

### Basic RL Training

```python
import ray
from ray.rllib.algorithms.ppo import PPOConfig

# Initialize Ray
ray.init()

# Configure PPO algorithm
config = (
    PPOConfig()
    .environment(env="CartPole-v1")
    .rollouts(num_rollout_workers=2)
    .training(lr=0.0001, train_batch_size=4000)
    .evaluation(evaluation_interval=10)
)

# Build algorithm
algo = config.build()

try:
    # Training loop
    for i in range(100):
        result = algo.train()
        print(f"Iteration {i}: reward={result['episode_reward_mean']}")

        # Save a checkpoint after every 10th completed iteration
        # (10, 20, ..., 100) so that the final model is always saved.
        if (i + 1) % 10 == 0:
            checkpoint_path = algo.save()
            print(f"Checkpoint saved at {checkpoint_path}")
finally:
    # Clean up even if training raises
    algo.stop()
    ray.shutdown()
```

### Custom Environment

```python
import gym
import ray
from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.env.env_context import EnvContext
from ray.tune.registry import register_env


class CustomEnv(gym.Env):
    """Toy environment: random observations, reward 1.0 for action 1.

    Episodes end after ``max_episode_steps`` steps (read from ``env_config``,
    default 100). Without a terminal step no episode would ever finish, and
    RLlib could never report episode reward statistics.
    """

    def __init__(self, config: EnvContext):
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Box(-1, 1, shape=(4,))
        self.config = config
        self.max_episode_steps = config.get("max_episode_steps", 100)
        self._steps = 0

    def reset(self):
        self._steps = 0
        return self.observation_space.sample()

    def step(self, action):
        self._steps += 1
        obs = self.observation_space.sample()
        reward = 1.0 if action == 1 else 0.0
        done = self._steps >= self.max_episode_steps
        info = {}
        return obs, reward, done, info


# Register environment; the creator receives the env_config dict
register_env("custom_env", lambda config: CustomEnv(config))

ray.init()

# Train on custom environment
config = (
    DQNConfig()
    .environment(env="custom_env", env_config={"param": "value"})
    .training(lr=0.001)
)

algo = config.build()

try:
    for i in range(50):
        result = algo.train()
        print(f"Episode reward: {result['episode_reward_mean']}")
finally:
    algo.stop()
    ray.shutdown()
```

### Multi-Agent RL

```python
import gym
import ray
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.tune.registry import register_env


class MultiAgentCustomEnv(MultiAgentEnv):
    """Two-agent toy environment with shared observation/action spaces.

    Every agent receives a reward of 1.0 per step. The episode ends for all
    agents (``"__all__": True``) after ``max_episode_steps`` steps (read from
    the env config, default 100).
    """

    def __init__(self, config):
        # Initialize the base class first; it sets up internal attributes.
        super().__init__()
        self.agents = ["agent_1", "agent_2"]
        self._agent_ids = set(self.agents)
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Box(-1, 1, shape=(4,))
        self.max_episode_steps = config.get("max_episode_steps", 100)
        self._steps = 0

    def reset(self):
        self._steps = 0
        return {agent: self.observation_space.sample() for agent in self.agents}

    def step(self, action_dict):
        self._steps += 1
        obs = {agent: self.observation_space.sample() for agent in self.agents}
        rewards = {agent: 1.0 for agent in self.agents}
        # "__all__" must eventually become True, or episodes never finish.
        dones = {"__all__": self._steps >= self.max_episode_steps}
        infos = {agent: {} for agent in self.agents}
        return obs, rewards, dones, infos


register_env("multi_agent_env", lambda config: MultiAgentCustomEnv(config))

ray.init()

config = (
    PPOConfig()
    .environment(env="multi_agent_env")
    .multi_agent(
        policies={
            "policy_1": (None, None, None, {}),
            "policy_2": (None, None, None, {}),
        },
        policy_mapping_fn=lambda agent_id, episode, **kwargs: (
            "policy_1" if agent_id == "agent_1" else "policy_2"
        ),
    )
)

algo = config.build()

try:
    for i in range(30):
        result = algo.train()
        print(f"Iteration {i}: {result['episode_reward_mean']}")
finally:
    algo.stop()
    ray.shutdown()
```

### Custom Model

```python
import ray
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
import torch
import torch.nn as nn


class CustomModel(TorchModelV2, nn.Module):
    """Shared MLP trunk with separate policy-logits and value heads."""

    def __init__(self, obs_space, action_space, num_outputs,
                 model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space,
                              num_outputs, model_config, name)
        nn.Module.__init__(self)

        self.shared_layers = nn.Sequential(
            nn.Linear(obs_space.shape[0], 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
        )

        self.policy_head = nn.Linear(128, num_outputs)
        self.value_head = nn.Linear(128, 1)
        # Set by forward() so that value_function() can return the value
        # estimates for the most recent batch.
        self._value = None

    def forward(self, input_dict, state, seq_lens):
        features = self.shared_layers(input_dict["obs"].float())
        logits = self.policy_head(features)
        self._value = self.value_head(features).squeeze(1)
        return logits, state

    def value_function(self):
        # Only valid after forward() has populated self._value.
        assert self._value is not None, "must call forward() first"
        return self._value


# Register custom model
ModelCatalog.register_custom_model("custom_model", CustomModel)

ray.init()

config = (
    PPOConfig()
    .environment(env="CartPole-v1")
    # A TorchModelV2 requires the PyTorch framework. Set it explicitly,
    # because older RLlib releases default to TensorFlow.
    .framework("torch")
    .training(model={"custom_model": "custom_model"})
)

algo = config.build()

try:
    for i in range(20):
        result = algo.train()
        print(f"Reward: {result['episode_reward_mean']}")
finally:
    algo.stop()
    ray.shutdown()
```

### Loading and Using a Trained Agent

```python
import gym
import ray
from ray.rllib.algorithms.ppo import PPO

ray.init()

# Restore trained algorithm
algo = PPO.from_checkpoint("/path/to/checkpoint")

# Create environment for evaluation
env = gym.make("CartPole-v1")

try:
    # Run episodes with trained agent
    for episode in range(5):
        obs = env.reset()
        done = False
        total_reward = 0

        while not done:
            # explore=False disables exploration so the evaluation reflects
            # the learned policy rather than exploratory actions.
            action = algo.compute_single_action(obs, explore=False)
            obs, reward, done, info = env.step(action)
            total_reward += reward
            env.render()

        print(f"Episode {episode}: Total reward = {total_reward}")
finally:
    # Release the environment, the algorithm, and Ray even on error
    env.close()
    algo.stop()
    ray.shutdown()
```