
# Hyperparameter Optimization

Hyperparameter sampling and optimization utilities using Optuna. Provides algorithm-specific parameter samplers, conversion functions, and distributed optimization support for finding optimal hyperparameters across different RL algorithms.

## Core Imports

```python
from rl_zoo3.hyperparams_opt import (
    sample_ppo_params,
    sample_sac_params,
    sample_dqn_params,
    sample_td3_params,
    sample_a2c_params,
    sample_ars_params,
    convert_onpolicy_params,
    convert_offpolicy_params,
    convert_ars_params
)
import optuna
from typing import Any  # `dict` is a builtin and cannot be imported from typing
```

## Capabilities

### Parameter Conversion Functions

Functions for converting sampled hyperparameters into the format expected by different algorithm families.

```python { .api }
def convert_onpolicy_params(sampled_params: dict[str, Any]) -> dict[str, Any]:
    """
    Convert sampled hyperparameters for on-policy algorithms (PPO, A2C, TRPO).

    Parameters:
    - sampled_params: Raw hyperparameters from Optuna sampling

    Returns:
    dict: Converted hyperparameters ready for algorithm use
    """

def convert_offpolicy_params(sampled_params: dict[str, Any]) -> dict[str, Any]:
    """
    Convert sampled hyperparameters for off-policy algorithms (SAC, TD3, DQN).

    Parameters:
    - sampled_params: Raw hyperparameters from Optuna sampling

    Returns:
    dict: Converted hyperparameters ready for algorithm use
    """

def convert_ars_params(sampled_params: dict[str, Any]) -> dict[str, Any]:
    """
    Convert sampled hyperparameters for ARS algorithm.

    Parameters:
    - sampled_params: Raw hyperparameters from Optuna sampling

    Returns:
    dict: Converted ARS-specific hyperparameters
    """
```
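
The docstrings above do not say what a conversion actually does. The sketch below illustrates one plausible design. It assumes a sampler that stores some values in a search-friendly encoding: the raw keys `one_minus_gamma`, `batch_size_pow`, `n_steps_pow`, and `net_arch` and the `NET_ARCHS` table are hypothetical here. An on-policy converter might map those values back to the keyword arguments an algorithm constructor accepts. This is not the rl_zoo3 implementation.

```python
from typing import Any

# Hypothetical mapping from architecture names to layer sizes
NET_ARCHS = {
    "tiny": {"pi": [64], "vf": [64]},
    "small": {"pi": [64, 64], "vf": [64, 64]},
    "medium": {"pi": [256, 256], "vf": [256, 256]},
}


def convert_onpolicy_params_sketch(sampled_params: dict[str, Any]) -> dict[str, Any]:
    """Illustrative converter (hypothetical keys), not the library's code."""
    params = dict(sampled_params)  # copy so the trial's raw values stay untouched

    # Sampling 1 - gamma on a log scale explores values close to 1 more finely
    params["gamma"] = 1.0 - params.pop("one_minus_gamma")

    # Sizes sampled as exponents become powers of two
    for name in ("batch_size", "n_steps"):
        key = f"{name}_pow"
        if key in params:
            params[name] = 2 ** params.pop(key)

    # An architecture name becomes the nested policy_kwargs structure
    params["policy_kwargs"] = {"net_arch": NET_ARCHS[params.pop("net_arch")]}
    return params


raw = {"one_minus_gamma": 0.01, "batch_size_pow": 6, "n_steps_pow": 9,
       "net_arch": "small", "learning_rate": 3e-4}
print(convert_onpolicy_params_sketch(raw))
```

The point of such a split is that the search space can stay simple (integers, categories, log-scaled floats) while the converter produces exactly what the algorithm expects.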

### PPO Parameter Sampling

Sampling functions for Proximal Policy Optimization hyperparameters.

```python { .api }
def sample_ppo_params(
    trial: optuna.Trial,
    n_actions: int,
    n_envs: int,
    additional_args: dict
) -> dict[str, Any]:
    """
    Sample hyperparameters for PPO algorithm.

    Parameters:
    - trial: Optuna trial object for parameter sampling
    - n_actions: Number of actions in the action space
    - n_envs: Number of parallel environments
    - additional_args: Additional algorithm-specific arguments

    Returns:
    dict: Sampled PPO hyperparameters including learning_rate, n_steps,
        batch_size, n_epochs, gamma, gae_lambda, clip_range, ent_coef, etc.
    """

def sample_ppo_lstm_params(
    trial: optuna.Trial,
    n_actions: int,
    n_envs: int,
    additional_args: dict
) -> dict[str, Any]:
    """
    Sample hyperparameters for PPO with LSTM policy.

    Parameters:
    - trial: Optuna trial object
    - n_actions: Number of actions
    - n_envs: Number of environments
    - additional_args: Additional arguments

    Returns:
    dict: Sampled PPO-LSTM hyperparameters with LSTM-specific parameters
    """
```
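
`sample_ppo_params` receives `n_envs` because one PPO rollout collects `n_steps * n_envs` transitions, which are then split into mini-batches of `batch_size`. Stable-Baselines3's PPO warns when that rollout size is not a multiple of `batch_size`, because the last mini-batch of each epoch is then truncated. The following is a minimal sketch of a hypothetical post-sampling helper (`fit_batch_size` is not part of rl_zoo3). It keeps `batch_size` within the rollout and makes it a divisor of the rollout size.

```python
def fit_batch_size(batch_size: int, n_steps: int, n_envs: int) -> int:
    """Hypothetical helper: largest valid mini-batch size <= batch_size.

    Valid means it divides the rollout size n_steps * n_envs, so every
    mini-batch in an epoch has the same size.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    rollout_size = n_steps * n_envs
    candidate = min(batch_size, rollout_size)
    # Walk down until the candidate divides the rollout evenly (1 always does)
    while rollout_size % candidate != 0:
        candidate -= 1
    return candidate


print(fit_batch_size(64, n_steps=128, n_envs=4))    # 64: already divides 512
print(fit_batch_size(1024, n_steps=128, n_envs=4))  # 512: capped at the rollout size
print(fit_batch_size(96, n_steps=100, n_envs=3))    # 75: largest divisor of 300 below 96
```

Walking down to a divisor is only one option. Restricting both `n_steps` and `batch_size` to powers of two avoids the problem by construction, or the truncated final mini-batch can simply be accepted.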

### SAC Parameter Sampling

Sampling functions for Soft Actor-Critic hyperparameters.

```python { .api }
def sample_sac_params(
    trial: optuna.Trial,
    n_actions: int,
    n_envs: int,
    additional_args: dict
) -> dict[str, Any]:
    """
    Sample hyperparameters for SAC algorithm.

    Parameters:
    - trial: Optuna trial object for parameter sampling
    - n_actions: Number of actions in the action space
    - n_envs: Number of parallel environments (typically 1 for SAC)
    - additional_args: Additional algorithm-specific arguments

    Returns:
    dict: Sampled SAC hyperparameters including learning_rate, buffer_size,
        batch_size, tau, gamma, train_freq, gradient_steps, ent_coef, etc.
    """
```
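
Among the sampled SAC values, `tau` sets how quickly the target networks track the online networks. The following minimal sketch shows the soft (Polyak) update that `tau` parameterizes. It uses plain lists of floats in place of network weights, purely for illustration.

```python
def polyak_update(online: list[float], target: list[float], tau: float) -> list[float]:
    """Soft target update: target <- tau * online + (1 - tau) * target."""
    if not 0.0 <= tau <= 1.0:
        raise ValueError("tau must be in [0, 1]")
    return [tau * o + (1.0 - tau) * t for o, t in zip(online, target)]


online = [1.0, 2.0]
target = [0.0, 0.0]
print(polyak_update(online, target, tau=0.005))  # target moves 0.5% of the way
print(polyak_update(online, target, tau=1.0))    # tau = 1 copies the online weights
```

Small values such as 0.005 (Stable-Baselines3's default for SAC) make the bootstrapped targets change slowly. A `tau` of 1 turns the update into a hard copy, which is how DQN-style target updates behave.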

### DQN Parameter Sampling

Sampling functions for Deep Q-Network and its variants.

```python { .api }
def sample_dqn_params(
    trial: optuna.Trial,
    n_actions: int,
    n_envs: int,
    additional_args: dict
) -> dict[str, Any]:
    """
    Sample hyperparameters for DQN algorithm.

    Parameters:
    - trial: Optuna trial object
    - n_actions: Number of discrete actions
    - n_envs: Number of environments
    - additional_args: Additional arguments

    Returns:
    dict: Sampled DQN hyperparameters including learning_rate, buffer_size,
        batch_size, tau, gamma, train_freq, target_update_interval, etc.
    """

def sample_qrdqn_params(
    trial: optuna.Trial,
    n_actions: int,
    n_envs: int,
    additional_args: dict
) -> dict[str, Any]:
    """
    Sample hyperparameters for QR-DQN (Quantile Regression DQN).

    Parameters:
    - trial: Optuna trial object
    - n_actions: Number of actions
    - n_envs: Number of environments
    - additional_args: Additional arguments

    Returns:
    dict: Sampled QR-DQN hyperparameters with quantile-specific parameters
    """
```
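
The quantile-specific parameters for QR-DQN typically include the number of quantiles each action's output predicts. This is `n_quantiles` in sb3-contrib's `QRDQN` policy; treat the exact key the sampler uses as an assumption. Each output estimates the return distribution at a fixed quantile midpoint tau_i = (2i + 1) / (2N). A minimal sketch:

```python
def quantile_midpoints(n_quantiles: int) -> list[float]:
    """Quantile fractions targeted by QR-DQN: tau_i = (2i + 1) / (2N)."""
    if n_quantiles < 1:
        raise ValueError("n_quantiles must be positive")
    return [(2 * i + 1) / (2 * n_quantiles) for i in range(n_quantiles)]


print(quantile_midpoints(4))  # [0.125, 0.375, 0.625, 0.875]
```

More quantiles give a finer approximation of the return distribution, at the cost of a larger output layer.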

### TD3 Parameter Sampling

Sampling functions for Twin Delayed Deep Deterministic Policy Gradient.

```python { .api }
def sample_td3_params(
    trial: optuna.Trial,
    n_actions: int,
    n_envs: int,
    additional_args: dict
) -> dict[str, Any]:
    """
    Sample hyperparameters for TD3 algorithm.

    Parameters:
    - trial: Optuna trial object
    - n_actions: Number of continuous actions
    - n_envs: Number of environments
    - additional_args: Additional arguments

    Returns:
    dict: Sampled TD3 hyperparameters including learning_rate, buffer_size,
        batch_size, tau, gamma, train_freq, policy_delay, target_policy_noise, etc.
    """
```
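
Two of the TD3-specific values are easier to read in context:

- `target_policy_noise` sets the scale of the clipped Gaussian noise added to target actions (target policy smoothing).
- `policy_delay` sets how many critic updates happen per actor and target update.

The following minimal standard-library sketch shows both. The helper names are hypothetical, and the clip value of 0.5 mirrors Stable-Baselines3's default `target_noise_clip`.

```python
import random
from typing import Optional


def smoothed_target_action(
    action: list[float],
    noise_std: float,
    noise_clip: float = 0.5,
    low: float = -1.0,
    high: float = 1.0,
    rng: Optional[random.Random] = None,
) -> list[float]:
    """Target policy smoothing: add clipped Gaussian noise, then clip to the action bounds."""
    rng = rng if rng is not None else random.Random()
    smoothed = []
    for a in action:
        # Clipping the noise keeps the smoothed target close to the policy's action
        noise = max(-noise_clip, min(noise_clip, rng.gauss(0.0, noise_std)))
        smoothed.append(max(low, min(high, a + noise)))
    return smoothed


def actor_update_steps(n_critic_updates: int, policy_delay: int) -> list[int]:
    """Critic update indices (1-based) after which the actor and targets are also updated."""
    return [step for step in range(1, n_critic_updates + 1) if step % policy_delay == 0]


print(smoothed_target_action([0.9, -0.2], noise_std=0.2, rng=random.Random(0)))
print(actor_update_steps(6, policy_delay=2))  # [2, 4, 6]
```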

### A2C Parameter Sampling

Sampling functions for Advantage Actor-Critic.

```python { .api }
def sample_a2c_params(
    trial: optuna.Trial,
    n_actions: int,
    n_envs: int,
    additional_args: dict
) -> dict[str, Any]:
    """
    Sample hyperparameters for A2C algorithm.

    Parameters:
    - trial: Optuna trial object
    - n_actions: Number of actions
    - n_envs: Number of parallel environments
    - additional_args: Additional arguments

    Returns:
    dict: Sampled A2C hyperparameters including learning_rate, n_steps,
        gamma, gae_lambda, ent_coef, vf_coef, etc.
    """
```
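
The `gamma` and `gae_lambda` values sampled for A2C (and for PPO) feed generalized advantage estimation (GAE). The following is a minimal single-environment sketch of that computation over one rollout of `n_steps` transitions. It is a plain-Python illustration, not the Stable-Baselines3 rollout buffer.

```python
def compute_gae(
    rewards: list[float],
    values: list[float],
    dones: list[bool],
    last_value: float,
    gamma: float,
    gae_lambda: float,
) -> list[float]:
    """Generalized advantage estimates for one rollout.

    dones[t] is True when the episode ended at step t, so nothing is
    bootstrapped from the following state.
    """
    advantages = [0.0] * len(rewards)
    next_advantage = 0.0
    next_value = last_value
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # One-step TD error
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        # Exponentially weighted sum of future TD errors
        next_advantage = delta + gamma * gae_lambda * not_done * next_advantage
        advantages[t] = next_advantage
        next_value = values[t]
    return advantages


# With gamma = gae_lambda = 1 and zero value estimates, advantages are the rewards-to-go
print(compute_gae([1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [False, False, False], 0.0, 1.0, 1.0))
```

With `gae_lambda = 1`, the estimate becomes the discounted return minus the value baseline. Smaller values trade variance for bias, down to plain one-step TD errors at `gae_lambda = 0`.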

### TRPO Parameter Sampling

Sampling functions for Trust Region Policy Optimization.

```python { .api }
def sample_trpo_params(
    trial: optuna.Trial,
    n_actions: int,
    n_envs: int,
    additional_args: dict
) -> dict[str, Any]:
    """
    Sample hyperparameters for TRPO algorithm.

    Parameters:
    - trial: Optuna trial object
    - n_actions: Number of actions
    - n_envs: Number of environments
    - additional_args: Additional arguments

    Returns:
    dict: Sampled TRPO hyperparameters including learning_rate, n_steps,
        batch_size, gamma, gae_lambda, cg_max_steps, target_kl, etc.
    """
```
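
`cg_max_steps` caps the number of conjugate-gradient iterations TRPO uses to approximately solve F x = g for its natural-gradient step. Here F is the Fisher information matrix (accessed only through matrix-vector products) and g is the policy gradient. The following minimal sketch of that solver uses a small symmetric positive-definite matrix in place of Fisher-vector products.

```python
from typing import Callable


def conjugate_gradient(
    matvec: Callable[[list[float]], list[float]],
    b: list[float],
    max_steps: int,
    tol: float = 1e-10,
) -> list[float]:
    """Approximately solve A x = b for symmetric positive-definite A.

    matvec(v) must return A @ v; TRPO supplies Fisher-vector products here
    instead of an explicit matrix.
    """
    x = [0.0] * len(b)
    r = list(b)  # residual b - A x, with x = 0
    p = list(r)  # current search direction
    rs_old = sum(ri * ri for ri in r)
    for _ in range(max_steps):
        if rs_old < tol:
            break
        ap = matvec(p)
        alpha = rs_old / sum(pi * api for pi, api in zip(p, ap))
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, ap)]
        rs_new = sum(ri * ri for ri in r)
        p = [ri + (rs_new / rs_old) * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return x


A = [[4.0, 1.0], [1.0, 3.0]]


def matvec(v: list[float]) -> list[float]:
    return [sum(a * vi for a, vi in zip(row, v)) for row in A]


print(conjugate_gradient(matvec, [1.0, 2.0], max_steps=10))  # close to [1/11, 7/11]
```

For an n-dimensional problem, exact arithmetic converges in at most n iterations. TRPO's parameter vector is far larger, so a small `cg_max_steps` deliberately trades accuracy of the step direction for speed.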

### TQC Parameter Sampling

Sampling functions for Truncated Quantile Critics.

```python { .api }
def sample_tqc_params(
    trial: optuna.Trial,
    n_actions: int,
    n_envs: int,
    additional_args: dict
) -> dict[str, Any]:
    """
    Sample hyperparameters for TQC algorithm.

    Parameters:
    - trial: Optuna trial object
    - n_actions: Number of actions
    - n_envs: Number of environments
    - additional_args: Additional arguments

    Returns:
    dict: Sampled TQC hyperparameters with quantile critic parameters
    """
```
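
TQC's quantile critic parameters decide how many pooled quantile estimates are discarded before the target is formed. In sb3-contrib these are `n_quantiles`, `n_critics`, and `top_quantiles_to_drop_per_net`; whether `sample_tqc_params` tunes all three is an assumption. The following minimal sketch shows the truncation step.

```python
def truncated_target_quantiles(
    critic_quantiles: list[list[float]], top_quantiles_to_drop_per_net: int
) -> list[float]:
    """Pool quantiles from all critics, sort them, and drop the largest ones.

    Dropping the top atoms counteracts overestimation of the target value.
    """
    n_critics = len(critic_quantiles)
    pooled = sorted(q for critic in critic_quantiles for q in critic)
    n_keep = len(pooled) - top_quantiles_to_drop_per_net * n_critics
    if n_keep <= 0:
        raise ValueError("cannot drop every quantile")
    return pooled[:n_keep]


# Two critics with three quantiles each; drop the top one per critic
print(truncated_target_quantiles([[1.0, 4.0, 6.0], [2.0, 3.0, 5.0]], 1))  # [1.0, 2.0, 3.0, 4.0]
```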

### ARS Parameter Sampling

Sampling functions for Augmented Random Search.

```python { .api }
def sample_ars_params(
    trial: optuna.Trial,
    n_actions: int,
    n_envs: int,
    additional_args: dict
) -> dict[str, Any]:
    """
    Sample hyperparameters for ARS algorithm.

    Parameters:
    - trial: Optuna trial object
    - n_actions: Number of actions
    - n_envs: Number of environments
    - additional_args: Additional arguments

    Returns:
    dict: Sampled ARS hyperparameters including n_delta, n_top, learning_rate,
        delta_std, zero_policy, etc.
    """
```
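
The ARS values map onto one update of basic Augmented Random Search:

- `n_delta` random directions are each evaluated with both signs at scale `delta_std`.
- Only the `n_top` best directions contribute to the update.
- The step is divided by the standard deviation of the rewards that were kept.

The following minimal sketch runs this on a toy objective, starting from a zero policy (the setting `zero_policy` refers to). The objective and helper names are hypothetical, and the sketch omits the observation normalization used by some ARS variants.

```python
import random
import statistics
from typing import Callable


def ars_step(
    theta: list[float],
    reward_fn: Callable[[list[float]], float],
    n_delta: int,
    n_top: int,
    learning_rate: float,
    delta_std: float,
    rng: random.Random,
) -> list[float]:
    """One basic ARS update using the n_top best of n_delta random directions."""
    if not 1 <= n_top <= n_delta:
        raise ValueError("need 1 <= n_top <= n_delta")

    # Evaluate each random direction with both signs
    results = []
    for _ in range(n_delta):
        delta = [rng.gauss(0.0, 1.0) for _ in theta]
        r_plus = reward_fn([t + delta_std * d for t, d in zip(theta, delta)])
        r_minus = reward_fn([t - delta_std * d for t, d in zip(theta, delta)])
        results.append((r_plus, r_minus, delta))

    # Keep the directions whose better side scored highest
    results.sort(key=lambda item: max(item[0], item[1]), reverse=True)
    top = results[:n_top]

    # Scale the step by the spread of the rewards that were kept
    kept_rewards = [r for r_plus, r_minus, _ in top for r in (r_plus, r_minus)]
    reward_std = statistics.pstdev(kept_rewards) or 1.0  # avoid dividing by zero

    step = [0.0] * len(theta)
    for r_plus, r_minus, delta in top:
        for i, d in enumerate(delta):
            step[i] += (r_plus - r_minus) * d
    scale = learning_rate / (n_top * reward_std)
    return [t + scale * s for t, s in zip(theta, step)]


def toy_reward(theta: list[float]) -> float:
    # Hypothetical objective with its maximum at (1, -2)
    return -((theta[0] - 1.0) ** 2 + (theta[1] + 2.0) ** 2)


rng = random.Random(0)
theta = [0.0, 0.0]  # zero policy initialization
for _ in range(200):
    theta = ars_step(theta, toy_reward, n_delta=8, n_top=4,
                     learning_rate=0.05, delta_std=0.1, rng=rng)
print(theta)
```

Dividing by the reward standard deviation makes the step size roughly independent of the reward scale. This is one reason ARS needs little tuning beyond the values listed above.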

### HER Parameter Sampling

Sampling functions for Hindsight Experience Replay parameters.

```python { .api }
def sample_her_params(
    trial: optuna.Trial,
    hyperparams: dict[str, Any],
    her_kwargs: dict[str, Any]
) -> dict[str, Any]:
    """
    Sample hyperparameters for HER (Hindsight Experience Replay).

    Parameters:
    - trial: Optuna trial object
    - hyperparams: Base algorithm hyperparameters
    - her_kwargs: HER-specific keyword arguments

    Returns:
    dict: Updated hyperparameters with HER configuration
    """
```
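
`sample_her_params` layers HER settings onto the base hyperparameters. In Stable-Baselines3's `HerReplayBuffer`, the main settings are `n_sampled_goal` (virtual goals per real transition) and `goal_selection_strategy` (such as `"future"`). Which of these the sampler actually returns is an assumption here. The following minimal sketch, with hypothetical helper names, shows what the two settings control.

```python
import random


def her_relabel_fraction(n_sampled_goal: int) -> float:
    """Share of sampled transitions whose goal is replaced by a virtual goal."""
    return 1.0 - 1.0 / (n_sampled_goal + 1)


def sample_future_goal(
    achieved_goals: list[tuple[float, ...]], t: int, rng: random.Random
) -> tuple[float, ...]:
    """'future' strategy: reuse a goal achieved at step t or later in the same episode."""
    return achieved_goals[rng.randrange(t, len(achieved_goals))]


episode_goals = [(0.0,), (0.1,), (0.3,), (0.6,)]
print(her_relabel_fraction(4))  # 0.8: four in five sampled transitions are relabeled
print(sample_future_goal(episode_goals, t=2, rng=random.Random(0)))
```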

## Usage Examples

### Basic Hyperparameter Optimization

```python
import argparse

import optuna

from rl_zoo3.exp_manager import ExperimentManager
from rl_zoo3.hyperparams_opt import sample_ppo_params, convert_onpolicy_params


def objective(trial):
    # Sample raw hyperparameters for this trial
    sampled_params = sample_ppo_params(
        trial=trial,
        n_actions=2,  # CartPole has 2 actions
        n_envs=4,
        additional_args={}
    )

    # Convert them into keyword arguments for PPO
    hyperparams = convert_onpolicy_params(sampled_params)

    # Create the experiment manager
    args = argparse.Namespace(
        algo='ppo',
        env='CartPole-v1',
        n_timesteps=10000,
        eval_freq=2000,
        n_eval_episodes=5,
        verbose=0
    )

    exp_manager = ExperimentManager(
        args=args,
        algo='ppo',
        env_id='CartPole-v1',
        log_folder='./optim_logs',
        hyperparams=hyperparams,
        n_timesteps=10000,
        eval_freq=2000
    )

    # Set up and train (recent rl_zoo3 releases return a
    # (model, saved_hyperparams) tuple from setup_experiment())
    model, _ = exp_manager.setup_experiment()
    exp_manager.learn(model)

    # Return the performance metric. Placeholder: in practice, use the
    # mean reward recorded by the evaluation callback.
    return 200.0


# Run the optimization
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=20)

print("Best parameters:", study.best_params)
print("Best value:", study.best_value)
```

### Multi-Algorithm Optimization

```python
import optuna

from rl_zoo3.hyperparams_opt import (
    sample_ppo_params, sample_sac_params,
    convert_onpolicy_params, convert_offpolicy_params
)

# Both algorithms train on the same environment, and SAC needs a continuous
# (Box) action space; e.g. Pendulum-v1 has a 1-D continuous action
N_ACTIONS = 1


def multi_algo_objective(trial):
    # Let Optuna choose the algorithm as a categorical hyperparameter
    algo_name = trial.suggest_categorical('algorithm', ['ppo', 'sac'])

    if algo_name == 'ppo':
        sampled_params = sample_ppo_params(trial, n_actions=N_ACTIONS, n_envs=4, additional_args={})
        hyperparams = convert_onpolicy_params(sampled_params)
    else:  # 'sac'
        # Off-policy samplers may read these keys to decide whether to add HER settings
        sac_args = {'using_her_replay_buffer': False, 'her_kwargs': {}}
        sampled_params = sample_sac_params(trial, n_actions=N_ACTIONS, n_envs=1, additional_args=sac_args)
        hyperparams = convert_offpolicy_params(sampled_params)

    # Create and train a model with the selected algorithm and `hyperparams`
    # (training code similar to the example above), then evaluate it
    performance_score = 0.0  # Placeholder: replace with the evaluated mean reward
    return performance_score


# Optimize across algorithms
study = optuna.create_study(direction='maximize')
study.optimize(multi_algo_objective, n_trials=50)
```

### Distributed Optimization

```python
import argparse

import optuna

from rl_zoo3.exp_manager import ExperimentManager
from rl_zoo3.hyperparams_opt import sample_ppo_params, convert_onpolicy_params


def create_distributed_study():
    # Database storage lets several processes share one study; for runs across
    # machines, a server database (e.g. PostgreSQL or MySQL) is more robust than SQLite
    study = optuna.create_study(
        study_name='rl_zoo3_optimization',
        storage='sqlite:///optuna_study.db',
        direction='maximize',
        load_if_exists=True  # attach to the study if another process created it
    )
    return study


def distributed_objective(trial):
    # Sample parameters for the chosen algorithm
    algo = 'ppo'  # Could be parameterized

    if algo == 'ppo':
        # LunarLander has 4 discrete actions
        sampled_params = sample_ppo_params(trial, n_actions=4, n_envs=8, additional_args={})
        hyperparams = convert_onpolicy_params(sampled_params)
    else:
        raise ValueError(f"No sampler configured for {algo!r}")

    # Create the experiment manager with optimization settings
    args = argparse.Namespace(
        algo=algo,
        env='LunarLander-v2',
        n_timesteps=50000,
        eval_freq=5000,
        n_eval_episodes=10,
        verbose=0,
        seed=trial.number  # distinct seed per trial; the seed is not tuned
    )

    exp_manager = ExperimentManager(
        args=args,
        algo=algo,
        env_id='LunarLander-v2',
        log_folder=f'./optim_logs/trial_{trial.number}',
        hyperparams=hyperparams,
        n_timesteps=50000,
        eval_freq=5000
    )

    # Train and evaluate (recent rl_zoo3 releases return a
    # (model, saved_hyperparams) tuple from setup_experiment())
    model, _ = exp_manager.setup_experiment()
    exp_manager.learn(model)

    # Placeholder: in practice, return the mean reward from the evaluation callback
    performance = 0.0
    return performance


# Launch this script in several processes that share the storage;
# each process runs 10 trials of the same study
study = create_distributed_study()
study.optimize(distributed_objective, n_trials=10)
```

### Custom Parameter Sampling

```python
import optuna

from rl_zoo3.hyperparams_opt import convert_onpolicy_params


def sample_custom_ppo_params(trial, n_actions, n_envs, additional_args):
    """
    Custom PPO parameter sampling with different ranges.
    """
    # Learning rate with a log-uniform distribution
    learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True)

    # Batch size as a power of 2
    batch_size_exp = trial.suggest_int('batch_size_exp', 4, 8)  # 2^4 to 2^8
    batch_size = 2 ** batch_size_exp

    # Number of steps collected per environment in each rollout
    n_steps = trial.suggest_categorical('n_steps', [128, 256, 512, 1024, 2048])

    # Ensure batch_size <= n_steps * n_envs (the rollout buffer size)
    if batch_size > n_steps * n_envs:
        batch_size = n_steps * n_envs

    # Other hyperparameters
    gamma = trial.suggest_float('gamma', 0.9, 0.9999)
    gae_lambda = trial.suggest_float('gae_lambda', 0.8, 1.0)
    clip_range = trial.suggest_float('clip_range', 0.1, 0.4)
    ent_coef = trial.suggest_float('ent_coef', 1e-8, 1e-1, log=True)

    return {
        'learning_rate': learning_rate,
        'n_steps': n_steps,
        'batch_size': batch_size,
        'gamma': gamma,
        'gae_lambda': gae_lambda,
        'clip_range': clip_range,
        'ent_coef': ent_coef,
        'n_epochs': trial.suggest_int('n_epochs', 3, 10),
        'vf_coef': trial.suggest_float('vf_coef', 0.1, 1.0)
    }


# Use the custom sampler in an optimization objective
def custom_objective(trial):
    sampled_params = sample_custom_ppo_params(
        trial, n_actions=4, n_envs=8, additional_args={}
    )
    # These values are already in final form; if your convert_onpolicy_params
    # expects the raw keys produced by sample_ppo_params, pass them on directly
    hyperparams = convert_onpolicy_params(sampled_params)

    # ... rest of training code (as in the examples above)
    performance = 0.0  # Placeholder: replace with the evaluated mean reward
    return performance
```
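
The custom sampler above passes `log=True` for `learning_rate` and `ent_coef`. This makes Optuna work with those parameters on a logarithmic scale, so with random sampling each decade of the range is equally likely. The following standard-library sketch shows that distribution for illustration only; `sample_log_uniform` is a hypothetical helper, not Optuna code.

```python
import math
import random


def sample_log_uniform(low: float, high: float, rng: random.Random) -> float:
    """Draw from [low, high] uniformly in log space, as log=True requests."""
    if not 0 < low < high:
        raise ValueError("need 0 < low < high")
    return math.exp(rng.uniform(math.log(low), math.log(high)))


rng = random.Random(42)
samples = [sample_log_uniform(1e-5, 1e-2, rng) for _ in range(30_000)]

# Count draws per decade: [1e-5, 1e-4), [1e-4, 1e-3), [1e-3, 1e-2]
counts = [0, 0, 0]
for s in samples:
    decade = int(math.floor(math.log10(s))) + 5  # 0, 1 or 2
    counts[max(0, min(decade, 2))] += 1
print(counts)  # roughly 10,000 per decade
```

Sampling the same range linearly would put about 90% of draws in the top decade [1e-3, 1e-2]. That is why log scaling is the usual choice for learning rates and entropy coefficients.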

### Integration with ExperimentManager

```python
import argparse

from rl_zoo3.exp_manager import ExperimentManager

# ExperimentManager handles optimization automatically
args = argparse.Namespace(
    algo='ppo',
    env='CartPole-v1',
    n_timesteps=20000,
    eval_freq=2000,
    optimize_hyperparameters=True,  # Enable optimization
    n_trials=30,
    n_jobs=2,
    sampler='tpe',
    pruner='median',
    study_name='ppo_cartpole_study',
    storage='sqlite:///ppo_optimization.db'
)

# Automatic hyperparameter optimization; the constructor arguments (not the
# args namespace) configure the run
exp_manager = ExperimentManager(
    args=args,
    algo='ppo',
    env_id='CartPole-v1',
    log_folder='./optim_logs',
    n_timesteps=20000,
    eval_freq=2000,
    optimize_hyperparameters=True,
    n_trials=30,
    n_jobs=2,
    sampler='tpe',
    pruner='median',
    study_name='ppo_cartpole_study',
    storage='sqlite:///ppo_optimization.db'
)

# setup_experiment() reads the default hyperparameters and prepares the log
# folder; with optimization enabled it returns None instead of a model
exp_manager.setup_experiment()

# This runs the full optimization process
exp_manager.hyperparameters_optimization()
```

## Supported Samplers and Pruners

The `sampler` and `pruner` arguments select the Optuna sampler and pruner used during optimization:

**Samplers:**

- `'tpe'`: Tree-structured Parzen Estimator (default; a good choice in most cases)
- `'random'`: Random sampling (a baseline)
- `'cmaes'`: CMA-ES (suited to continuous parameters)

**Pruners:**

- `'median'`: Median pruner (default; stops a trial whose intermediate result is worse than the median of earlier trials at the same step)
- `'halving'`: Successive halving (more aggressive pruning)
- `'none'`: No pruning
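
As a rough illustration of the median rule: a trial is stopped at an evaluation step when its best intermediate result so far is worse than the median of what earlier trials reported at that same step. The following is a simplified standard-library sketch for a maximized objective; `should_prune` is a hypothetical helper. Optuna's `MedianPruner` also supports startup trials, warm-up steps, and an evaluation interval, which this sketch omits.

```python
import statistics


def should_prune(
    trial_values: list[float], previous_trials: list[list[float]], step: int
) -> bool:
    """Simplified median-pruning rule for a maximized objective.

    trial_values[i] is the current trial's intermediate result at step i;
    previous_trials holds the intermediate results of earlier trials.
    """
    best_so_far = max(trial_values[: step + 1])
    # Only compare against trials that reported a value at this step
    others = [values[step] for values in previous_trials if len(values) > step]
    if not others:
        return False  # nothing to compare against yet
    return best_so_far < statistics.median(others)


history = [[10.0, 50.0, 120.0], [20.0, 80.0, 150.0], [5.0, 40.0, 100.0]]
print(should_prune([8.0, 30.0], history, step=1))   # True: 30 < median(50, 80, 40) = 50
print(should_prune([25.0, 90.0], history, step=1))  # False
```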