or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

core-distributed.md, data-processing.md, distributed-training.md, hyperparameter-tuning.md, index.md, model-serving.md, reinforcement-learning.md, utilities-advanced.md

docs/reinforcement-learning.md

0

# Reinforcement Learning

1

2

Ray RLlib provides reinforcement learning algorithms and environment APIs, with support for distributed training and multiple deep learning frameworks (TensorFlow and PyTorch). It includes implementations of state-of-the-art RL algorithms and tools for developing custom environments, custom models, and multi-agent setups.

3

4

## Capabilities

5

6

### Core RL Framework

7

8

Base reinforcement learning functionality and algorithm management.

9

10

```python { .api }

11

class Policy:
    """Base class for RL policies.

    Reference stub: each method lists the public signature and documents its
    arguments and return value; bodies are intentionally omitted.
    """

    def compute_actions(self, obs_batch, state_batches=None,
                        prev_action_batch=None, prev_reward_batch=None,
                        info_batch=None, episodes=None, explore=None,
                        timestep=None, **kwargs):
        """
        Compute actions for a batch of observations.

        Args:
            obs_batch: Batch of observations
            state_batches (list, optional): List of RNN state batches
            prev_action_batch: Previous actions
            prev_reward_batch: Previous rewards
            info_batch: Info dictionaries
            episodes: Episode objects
            explore (bool, optional): Whether to explore (same meaning as
                in ``compute_actions_from_input_dict``)
            timestep (int, optional): Current timestep

        Returns:
            tuple: (actions, state_outs, extra_info)
        """

    def compute_actions_from_input_dict(self, input_dict, explore=None,
                                        timestep=None, **kwargs):
        """
        Compute actions from input dictionary.

        Args:
            input_dict (dict): Input dictionary with observations
            explore (bool, optional): Whether to explore
            timestep (int, optional): Current timestep

        Returns:
            tuple: (actions, state_outs, extra_info)
        """

    def learn_on_batch(self, samples):
        """
        Learn from a batch of samples.

        Args:
            samples: Batch of training samples

        Returns:
            dict: Training statistics
        """

    def get_weights(self):
        """
        Get policy weights.

        Returns:
            dict: Policy weights (accepted by ``set_weights``)
        """

    def set_weights(self, weights):
        """
        Set policy weights.

        Args:
            weights (dict): Policy weights to set, e.g. as returned by
                ``get_weights``
        """

    def export_model(self, export_dir, onnx=None):
        """
        Export policy model.

        Args:
            export_dir (str): Directory to export to
            onnx (int, optional): ONNX opset version
        """

81

82

class Algorithm:
    """Base class for RL algorithms.

    Reference stub: each method lists the public signature and documents its
    arguments and return value; bodies are intentionally omitted.
    """

    def __init__(self, config=None, env=None, logger_creator=None):
        """
        Initialize RL algorithm.

        Args:
            config (dict, optional): Algorithm configuration
            env: Environment or environment string
            logger_creator: Logger creator function
        """

    def train(self):
        """
        Perform one training iteration.

        Returns:
            dict: Training results (the usage examples read
            ``result["episode_reward_mean"]`` from it)
        """

    def evaluate(self, duration_fn=None, evaluation_fn=None):
        """
        Evaluate current policy.

        Args:
            duration_fn: Function to determine evaluation duration
            evaluation_fn: Custom evaluation function

        Returns:
            dict: Evaluation results
        """

    def compute_single_action(self, observation, state=None,
                              prev_action=None, prev_reward=None,
                              info=None, policy_id="default_policy",
                              full_fetch=False, explore=None):
        """
        Compute single action from observation.

        Args:
            observation: Single observation
            state: RNN state
            prev_action: Previous action
            prev_reward: Previous reward
            info: Info dictionary
            policy_id (str): Policy ID to use (defaults to "default_policy")
            full_fetch (bool): Whether to return full info
            explore (bool, optional): Whether to explore; pass False for
                evaluation without exploratory actions

        Returns:
            Action or tuple with additional info (presumably the tuple form
            is returned when ``full_fetch`` is True — confirm)
        """

    def save(self, checkpoint_dir=None):
        """
        Save algorithm checkpoint.

        Args:
            checkpoint_dir (str, optional): Directory to save to

        Returns:
            str: Checkpoint path (suitable for ``restore``)
        """

    def restore(self, checkpoint_path):
        """
        Restore algorithm from checkpoint.

        Args:
            checkpoint_path (str): Path to checkpoint, e.g. as returned by
                ``save``
        """

    def stop(self):
        """Stop algorithm and cleanup resources."""

    def get_policy(self, policy_id="default_policy"):
        """
        Get policy by ID.

        Args:
            policy_id (str): Policy ID (defaults to "default_policy")

        Returns:
            Policy: Policy object
        """

    def add_policy(self, policy_id, policy_cls, observation_space=None,
                   action_space=None, config=None, policy_state=None):
        """
        Add new policy to algorithm.

        Args:
            policy_id (str): Policy ID
            policy_cls: Policy class
            observation_space: Observation space
            action_space: Action space
            config (dict, optional): Policy configuration
            policy_state: Policy state
        """

    def remove_policy(self, policy_id):
        """
        Remove policy from algorithm.

        Args:
            policy_id (str): Policy ID to remove (as given to ``add_policy``)
        """

190

```

191

192

### Environment Integration

193

194

Work with RL environments and wrappers.

195

196

```python { .api }

197

class MultiAgentEnv:
    """Base class for multi-agent environments.

    Observations, rewards, dones and infos are exchanged as dicts keyed by
    agent ID (see ``reset`` and ``step``).
    """

    def reset(self):
        """
        Reset environment.

        Returns:
            dict: Initial observations for each agent
        """

    def step(self, action_dict):
        """
        Step environment with actions.

        Args:
            action_dict (dict): Actions for each agent

        Returns:
            tuple: (obs_dict, reward_dict, done_dict, info_dict). The
            Multi-Agent RL usage example puts the special ``"__all__"`` key
            in ``done_dict`` (presumably to end the episode for every agent
            at once).
        """

    def render(self, mode="human"):
        """Render environment (``mode`` defaults to "human")."""

    def close(self):
        """Close environment."""

224

225

def make_multi_agent(env_name_or_creator):
    """
    Create a multi-agent environment class from a single-agent environment.

    Args:
        env_name_or_creator: Registered/Gym environment name, or a creator
            function ``env_config -> env`` for the single-agent environment.

    Returns:
        Type[MultiAgentEnv]: A MultiAgentEnv *subclass*, not an instance.
        Instantiate it with an env config, for example::

            MultiAgentCartPole = make_multi_agent("CartPole-v1")
            env = MultiAgentCartPole({"num_agents": 2})

        Each agent is backed by its own copy of the single-agent environment.
    """

235

236

class BaseEnv:
    """Base class for vectorized environments.

    Interaction is split into ``poll`` (collect results) and
    ``send_actions`` (submit actions) rather than a single ``step`` call.
    """

    def poll(self):
        """
        Poll the environments for their latest step results.

        NOTE(review): the original summary said "completed episodes"; the
        returned observations/rewards suggest results for every environment
        that is ready to act, not only finished episodes — confirm.

        Returns:
            tuple: (obs_dict, reward_dict, done_dict, info_dict, off_policy_actions_dict)
        """

    def send_actions(self, action_dict):
        """
        Send actions to environments.

        Args:
            action_dict (dict): Actions for each environment
        """

    def try_reset(self, env_id):
        """
        Try to reset specific environment.

        Args:
            env_id: Environment ID

        Returns:
            dict or None: Observation if reset successful, otherwise None
        """

265

```

266

267

### Configuration and Spaces

268

269

Configure algorithms, including the environment's observation and action spaces, through the chainable `AlgorithmConfig` builder.

270

271

```python { .api }

272

class AlgorithmConfig:
    """Configuration for RL algorithms.

    Builder-style reference stub: each configuration method documents that
    it returns the config itself so calls can be chained, and ``build()``
    turns the finished configuration into an ``Algorithm``.
    """

    def __init__(self, algo_class=None):
        """Initialize algorithm configuration."""

    def environment(self, env=None, *, env_config=None, observation_space=None,
                    action_space=None, **kwargs):
        """
        Configure environment settings.

        Args:
            env: Environment or environment string
            env_config (dict, optional): Environment configuration
            observation_space: Observation space
            action_space: Action space

        Returns:
            AlgorithmConfig: Self for chaining
        """

    def framework(self, framework=None, *, eager_tracing=None, **kwargs):
        """
        Configure ML framework.

        Args:
            framework (str, optional): Framework ("tf", "tf2", "torch")
            eager_tracing (bool, optional): Enable eager tracing

        Returns:
            AlgorithmConfig: Self for chaining
        """

    def resources(self, *, num_gpus=None, num_cpus_per_worker=None,
                  num_gpus_per_worker=None, **kwargs):
        """
        Configure resource usage.

        Args:
            num_gpus (float, optional): Number of GPUs
            num_cpus_per_worker (float, optional): CPUs per worker
            num_gpus_per_worker (float, optional): GPUs per worker

        Returns:
            AlgorithmConfig: Self for chaining
        """

    def rollouts(self, *, num_rollout_workers=None, num_envs_per_worker=None,
                 rollout_fragment_length=None, **kwargs):
        """
        Configure rollout collection.

        Args:
            num_rollout_workers (int, optional): Number of rollout workers
            num_envs_per_worker (int, optional): Environments per worker
            rollout_fragment_length (int, optional): Rollout fragment length

        Returns:
            AlgorithmConfig: Self for chaining
        """

    def training(self, *, lr=None, train_batch_size=None, model=None,
                 **kwargs):
        """
        Configure training settings.

        Args:
            lr (float, optional): Learning rate
            train_batch_size (int, optional): Training batch size
            model (dict, optional): Model configuration, e.g.
                ``{"custom_model": "<name registered with ModelCatalog>"}``

        Returns:
            AlgorithmConfig: Self for chaining
        """

    def evaluation(self, *, evaluation_interval=None, evaluation_duration=None,
                   **kwargs):
        """
        Configure evaluation settings.

        Args:
            evaluation_interval (int, optional): Evaluation interval
            evaluation_duration (int, optional): Evaluation duration

        Returns:
            AlgorithmConfig: Self for chaining
        """

    def multi_agent(self, *, policies=None, policy_mapping_fn=None,
                    policies_to_train=None, **kwargs):
        """
        Configure multi-agent settings.

        Args:
            policies (dict, optional): Policies to create, keyed by policy
                ID. Values may be ``(policy_cls, observation_space,
                action_space, config)`` tuples; ``None`` entries let RLlib
                use the algorithm's default policy class and infer the spaces
                from the environment.
            policy_mapping_fn (callable, optional): Maps an agent to a policy
                ID; called as ``fn(agent_id, episode, **kwargs)``.
            policies_to_train (list, optional): IDs of the policies to
                update during training; the others are only used for acting.

        Returns:
            AlgorithmConfig: Self for chaining
        """

    def build(self, env=None, logger_creator=None):
        """
        Build algorithm from configuration.

        Args:
            env: Environment override
            logger_creator: Logger creator override

        Returns:
            Algorithm: Built algorithm
        """

369

```

370

371

### Specific RL Algorithms

372

373

Implementations of specific RL algorithms.

374

375

```python { .api }

376

# Each algorithm is documented as a pair: a ``*Config`` builder (an
# AlgorithmConfig subclass) and the Algorithm subclass it configures.
# The usage examples create the config, chain settings, and call
# ``.build()``; presumably ``build()`` returns the paired Algorithm
# subclass (e.g. PPOConfig -> PPO) — confirm per algorithm.


class PPOConfig(AlgorithmConfig):
    """Configuration for Proximal Policy Optimization (builds ``PPO``)."""


class PPO(Algorithm):
    """Proximal Policy Optimization algorithm."""


class SACConfig(AlgorithmConfig):
    """Configuration for Soft Actor-Critic (builds ``SAC``)."""


class SAC(Algorithm):
    """Soft Actor-Critic algorithm."""


class DQNConfig(AlgorithmConfig):
    """Configuration for Deep Q-Network (builds ``DQN``)."""


class DQN(Algorithm):
    """Deep Q-Network algorithm."""


class A3CConfig(AlgorithmConfig):
    """Configuration for Asynchronous Advantage Actor-Critic (builds ``A3C``)."""


class A3C(Algorithm):
    """Asynchronous Advantage Actor-Critic algorithm."""


class IMPALAConfig(AlgorithmConfig):
    """Configuration for IMPALA (builds ``IMPALA``)."""


class IMPALA(Algorithm):
    """IMPALA algorithm."""

405

```

406

407

### Utilities and Helpers

408

409

Utility functions for RL development.

410

411

```python { .api }

412

def register_env(name, env_creator):
    """
    Register environment with Ray RLlib.

    Import it with ``from ray.tune.registry import register_env``; it is
    not part of ``ray.rllib.utils``. After registration, ``name`` can be
    passed as ``env`` to ``AlgorithmConfig.environment()``.

    Args:
        name (str): Environment name
        env_creator: Function that creates the environment. It is called
            with the ``env_config`` given to ``AlgorithmConfig.environment``,
            e.g. ``lambda config: CustomEnv(config)``.
    """

420

421

class ModelCatalog:
    """Catalog for registering custom models and preprocessors.

    Names registered here are referenced from a model config dict, as in the
    Custom Model usage example: ``.training(model={"custom_model": name})``.
    """

    @staticmethod
    def register_custom_model(model_name, model_class):
        """
        Register custom model.

        Args:
            model_name (str): Model name (the value later given under the
                ``"custom_model"`` key of the model config)
            model_class: Model class, e.g. a ``TorchModelV2`` subclass as in
                the Custom Model usage example
        """

    @staticmethod
    def register_custom_preprocessor(preprocessor_name, preprocessor_class):
        """
        Register custom preprocessor.

        Args:
            preprocessor_name (str): Preprocessor name
            preprocessor_class: Preprocessor class
        """

    @staticmethod
    def register_custom_action_dist(action_dist_name, action_dist_class):
        """
        Register custom action distribution.

        Args:
            action_dist_name (str): Action distribution name
            action_dist_class: Action distribution class
        """

453

454

def rollout(agent, env_name, num_steps=None, num_episodes=1,
            no_render=False, video_dir=None):
    """
    Rollout trained agent in environment.

    Args:
        agent: Trained agent/algorithm (e.g. an ``Algorithm`` returned by
            ``AlgorithmConfig.build()`` or restored from a checkpoint)
        env_name (str): Environment name
        num_steps (int, optional): Number of steps
        num_episodes (int): Number of episodes (defaults to 1)
        no_render (bool): Whether to disable rendering (rendering is on by
            default)
        video_dir (str, optional): Directory to save videos

    Returns:
        list: Episode rewards

    NOTE(review): which limit wins when both ``num_steps`` and
    ``num_episodes`` are set is not specified here — confirm against the
    targeted Ray version.
    """

470

```

471

472

## Usage Examples

473

474

### Basic RL Training

475

476

```python

477

import ray
from ray.rllib.algorithms.ppo import PPOConfig

# Initialize Ray
ray.init()

# Configure PPO algorithm
config = (
    PPOConfig()
    .environment(env="CartPole-v1")
    .rollouts(num_rollout_workers=2)
    .training(lr=0.0001, train_batch_size=4000)
    .evaluation(evaluation_interval=10)
)

# Build algorithm
algo = config.build()

try:
    # Training loop
    for i in range(100):
        result = algo.train()
        print(f"Iteration {i}: reward={result['episode_reward_mean']}")

        # Save a checkpoint after every 10th completed iteration
        # (iterations 10, 20, ..., 100). Testing ``i + 1`` rather than ``i``
        # avoids saving right after the very first iteration and makes sure
        # the final, most-trained weights are checkpointed.
        if (i + 1) % 10 == 0:
            checkpoint_path = algo.save()
            print(f"Checkpoint saved at {checkpoint_path}")
finally:
    # Clean up even if training raises, so rollout workers and the Ray
    # runtime are not left running.
    algo.stop()
    ray.shutdown()

506

```

507

508

### Custom Environment

509

510

```python

511

import gym
import ray
from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.env.env_context import EnvContext
from ray.tune.registry import register_env  # not ray.rllib.utils


class CustomEnv(gym.Env):
    """Toy single-agent environment: action 1 earns 1.0, action 0 earns 0.0.

    Episodes end after ``config["max_episode_steps"]`` steps (default 100).
    Without a termination condition RLlib would never complete an episode,
    so episode metrics such as ``episode_reward_mean`` would never be filled.
    """

    def __init__(self, config: EnvContext):
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Box(-1, 1, shape=(4,))
        self.config = config
        # EnvContext is a dict, so optional settings can be read with .get().
        self.max_episode_steps = config.get("max_episode_steps", 100)
        self._steps = 0

    def reset(self):
        self._steps = 0
        return self.observation_space.sample()

    def step(self, action):
        self._steps += 1
        obs = self.observation_space.sample()
        reward = 1.0 if action == 1 else 0.0
        done = self._steps >= self.max_episode_steps
        info = {}
        return obs, reward, done, info


# Register environment under a name RLlib can look up.
register_env("custom_env", lambda config: CustomEnv(config))

ray.init()

# Train on custom environment
config = (
    DQNConfig()
    .environment(env="custom_env", env_config={"param": "value"})
    .training(lr=0.001)
)

algo = config.build()

try:
    for i in range(50):
        result = algo.train()
        print(f"Episode reward: {result['episode_reward_mean']}")
finally:
    algo.stop()
    ray.shutdown()

550

```

551

552

### Multi-Agent RL

553

554

```python

555

import gym
import ray
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.tune.registry import register_env


class MultiAgentCustomEnv(MultiAgentEnv):
    """Two-agent toy environment in which every agent earns 1.0 per step.

    Episodes end for all agents after ``config["max_episode_steps"]`` steps
    (default 100), signalled through the ``"__all__"`` done flag.
    """

    def __init__(self, config):
        # Let the base class set up its bookkeeping attributes before this
        # environment assigns its own spaces and agent IDs.
        super().__init__()
        self.agents = ["agent_1", "agent_2"]
        self._agent_ids = set(self.agents)
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Box(-1, 1, shape=(4,))
        self.max_episode_steps = config.get("max_episode_steps", 100)
        self._steps = 0

    def reset(self):
        self._steps = 0
        return {agent: self.observation_space.sample() for agent in self.agents}

    def step(self, action_dict):
        self._steps += 1
        obs = {agent: self.observation_space.sample() for agent in self.agents}
        rewards = {agent: 1.0 for agent in self.agents}
        # "__all__" ends the episode for every agent at once; without it
        # ever becoming True no episode would complete.
        dones = {"__all__": self._steps >= self.max_episode_steps}
        infos = {agent: {} for agent in self.agents}
        return obs, rewards, dones, infos


# Forward the env_config RLlib passes in instead of discarding it.
register_env("multi_agent_env", lambda config: MultiAgentCustomEnv(config))

ray.init()

config = (
    PPOConfig()
    .environment(env="multi_agent_env")
    .multi_agent(
        policies={
            "policy_1": (None, None, None, {}),
            "policy_2": (None, None, None, {}),
        },
        policy_mapping_fn=lambda agent_id, episode, **kwargs: (
            "policy_1" if agent_id == "agent_1" else "policy_2"
        ),
    )
)

algo = config.build()

try:
    for i in range(30):
        result = algo.train()
        print(f"Iteration {i}: {result['episode_reward_mean']}")
finally:
    algo.stop()
    ray.shutdown()

600

```

601

602

### Custom Model

603

604

```python

605

import numpy as np
import ray
import torch
import torch.nn as nn
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2


class CustomModel(TorchModelV2, nn.Module):
    """Shared-trunk MLP with separate policy (logits) and value heads."""

    def __init__(self, obs_space, action_space, num_outputs,
                 model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space,
                              num_outputs, model_config, name)
        nn.Module.__init__(self)

        # Size the input layer from the full observation shape (paired with
        # the flattened ``obs_flat`` input in ``forward``) so the model also
        # works for multi-dimensional Box observations, not only 1-D ones.
        in_size = int(np.prod(obs_space.shape))

        self.shared_layers = nn.Sequential(
            nn.Linear(in_size, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
        )

        self.policy_head = nn.Linear(128, num_outputs)
        self.value_head = nn.Linear(128, 1)
        # Value estimates from the most recent forward pass.
        self._value = None

    def forward(self, input_dict, state, seq_lens):
        features = self.shared_layers(input_dict["obs_flat"])
        logits = self.policy_head(features)
        self._value = self.value_head(features).squeeze(1)
        return logits, state

    def value_function(self):
        # The value head is evaluated inside forward(); fail loudly instead
        # of returning None if it is queried too early.
        if self._value is None:
            raise ValueError("must call forward() first")
        return self._value


# Register custom model
ModelCatalog.register_custom_model("custom_model", CustomModel)

ray.init()

config = (
    PPOConfig()
    .environment(env="CartPole-v1")
    .training(model={"custom_model": "custom_model"})
)

algo = config.build()

try:
    for i in range(20):
        result = algo.train()
        print(f"Reward: {result['episode_reward_mean']}")
finally:
    algo.stop()
    ray.shutdown()

655

```

656

657

### Loading and Using Trained Agent

658

659

```python

660

import gym
import ray
from ray.rllib.algorithms.ppo import PPO

ray.init()

# Restore trained algorithm
algo = PPO.from_checkpoint("/path/to/checkpoint")

# Create environment for evaluation
env = gym.make("CartPole-v1")

try:
    # Run episodes with trained agent
    for episode in range(5):
        obs = env.reset()
        done = False
        total_reward = 0

        while not done:
            # explore=False: act with the learned policy instead of sampling
            # exploratory actions, so the reported reward reflects what the
            # agent has actually learned.
            action = algo.compute_single_action(obs, explore=False)
            obs, reward, done, info = env.step(action)
            total_reward += reward
            env.render()

        print(f"Episode {episode}: Total reward = {total_reward}")
finally:
    # Release the environment, the algorithm's workers and Ray even if an
    # episode raises.
    env.close()
    algo.stop()
    ray.shutdown()

688

```