# Hyperparameter Optimization

Utilities for sampling and optimizing hyperparameters with Optuna. The module provides algorithm-specific parameter samplers, conversion functions, and support for distributed optimization across the supported RL algorithms.

## Core Imports

```python
from rl_zoo3.hyperparams_opt import (
    sample_ppo_params,
    sample_sac_params,
    sample_dqn_params,
    sample_td3_params,
    sample_a2c_params,
    sample_ars_params,
    convert_onpolicy_params,
    convert_offpolicy_params,
    convert_ars_params,
)
import optuna
from typing import Any  # the built-in dict is used for generics (Python 3.9+)
```

## Capabilities

### Parameter Conversion Functions

Functions that convert sampled hyperparameters into the format expected by each algorithm family.

```python { .api }
def convert_onpolicy_params(sampled_params: dict[str, Any]) -> dict[str, Any]:
    """
    Convert sampled hyperparameters for on-policy algorithms (PPO, A2C, TRPO).

    Parameters:
    - sampled_params: Raw hyperparameters from Optuna sampling

    Returns:
    dict: Converted hyperparameters ready for algorithm use
    """

def convert_offpolicy_params(sampled_params: dict[str, Any]) -> dict[str, Any]:
    """
    Convert sampled hyperparameters for off-policy algorithms (SAC, TD3, DQN).

    Parameters:
    - sampled_params: Raw hyperparameters from Optuna sampling

    Returns:
    dict: Converted hyperparameters ready for algorithm use
    """

def convert_ars_params(sampled_params: dict[str, Any]) -> dict[str, Any]:
    """
    Convert sampled hyperparameters for ARS algorithm.

    Parameters:
    - sampled_params: Raw hyperparameters from Optuna sampling

    Returns:
    dict: Converted ARS-specific hyperparameters
    """
```
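
To make the idea of a conversion step concrete, here is a minimal sketch of the kind of transformation such a function might perform. It is not the library's implementation: the function name `convert_params_sketch`, the `NET_ARCHS` table, and the sampled keys `one_minus_gamma`, `batch_size_pow`, and `net_arch` are assumptions chosen for illustration.

```python
from typing import Any

# Hypothetical lookup table: the architecture names are illustrative only.
NET_ARCHS = {
    "small": [64, 64],
    "medium": [256, 256],
}


def convert_params_sketch(sampled_params: dict[str, Any]) -> dict[str, Any]:
    """Turn search-friendly values into constructor-ready hyperparameters."""
    hyperparams = dict(sampled_params)  # copy so the caller's dict is untouched

    # Sampling 1 - gamma (e.g. on a log scale) gives finer resolution near 1.
    if "one_minus_gamma" in hyperparams:
        hyperparams["gamma"] = 1.0 - hyperparams.pop("one_minus_gamma")

    # Sampling an exponent keeps batch sizes on powers of two.
    if "batch_size_pow" in hyperparams:
        hyperparams["batch_size"] = 2 ** hyperparams.pop("batch_size_pow")

    # A categorical name is expanded into nested policy keyword arguments.
    if "net_arch" in hyperparams:
        arch = NET_ARCHS[hyperparams.pop("net_arch")]
        hyperparams["policy_kwargs"] = {"net_arch": arch}

    return hyperparams


converted = convert_params_sketch(
    {
        "learning_rate": 3e-4,
        "one_minus_gamma": 0.01,
        "batch_size_pow": 6,
        "net_arch": "small",
    }
)
print(converted)
```

Separating sampling from conversion in this way lets the search space use whatever parameterization is easiest to explore, while the algorithm constructor still receives its usual argument names.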

### PPO Parameter Sampling

Sampling functions for Proximal Policy Optimization hyperparameters.

```python { .api }
def sample_ppo_params(
    trial: optuna.Trial,
    n_actions: int,
    n_envs: int,
    additional_args: dict
) -> dict[str, Any]:
    """
    Sample hyperparameters for PPO algorithm.

    Parameters:
    - trial: Optuna trial object for parameter sampling
    - n_actions: Number of actions in the action space
    - n_envs: Number of parallel environments
    - additional_args: Additional algorithm-specific arguments

    Returns:
    dict: Sampled PPO hyperparameters including learning_rate, n_steps,
    batch_size, n_epochs, gamma, gae_lambda, clip_range, ent_coef, etc.
    """

def sample_ppo_lstm_params(
    trial: optuna.Trial,
    n_actions: int,
    n_envs: int,
    additional_args: dict
) -> dict[str, Any]:
    """
    Sample hyperparameters for PPO with LSTM policy.

    Parameters:
    - trial: Optuna trial object
    - n_actions: Number of actions
    - n_envs: Number of environments
    - additional_args: Additional arguments

    Returns:
    dict: Sampled PPO-LSTM hyperparameters with LSTM-specific parameters
    """
```
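
PPO splits each rollout of `n_steps * n_envs` transitions into minibatches of `batch_size`, and Stable-Baselines3 warns when the rollout size is not a multiple of the batch size. The sketch below shows one way to filter candidate batch sizes before sampling; `valid_batch_sizes` is a hypothetical helper, not part of `rl_zoo3`.

```python
def valid_batch_sizes(n_steps: int, n_envs: int, candidates: list[int]) -> list[int]:
    """Keep only batch sizes that evenly divide the rollout buffer."""
    rollout_size = n_steps * n_envs
    return [b for b in candidates if b <= rollout_size and rollout_size % b == 0]


candidates = [32, 48, 64, 128, 256, 512]
usable = valid_batch_sizes(n_steps=128, n_envs=2, candidates=candidates)
print(usable)  # [32, 64, 128, 256]
```

With `n_steps=128` and `n_envs=2` the rollout holds 256 transitions, so 48 (not a divisor) and 512 (larger than the rollout) are dropped.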

### SAC Parameter Sampling

Sampling functions for Soft Actor-Critic hyperparameters.

```python { .api }
def sample_sac_params(
    trial: optuna.Trial,
    n_actions: int,
    n_envs: int,
    additional_args: dict
) -> dict[str, Any]:
    """
    Sample hyperparameters for SAC algorithm.

    Parameters:
    - trial: Optuna trial object for parameter sampling
    - n_actions: Number of actions in the action space
    - n_envs: Number of parallel environments (typically 1 for SAC)
    - additional_args: Additional algorithm-specific arguments

    Returns:
    dict: Sampled SAC hyperparameters including learning_rate, buffer_size,
    batch_size, tau, gamma, train_freq, gradient_steps, ent_coef, etc.
    """
```

### DQN Parameter Sampling

Sampling functions for Deep Q-Network and its variants.

```python { .api }
def sample_dqn_params(
    trial: optuna.Trial,
    n_actions: int,
    n_envs: int,
    additional_args: dict
) -> dict[str, Any]:
    """
    Sample hyperparameters for DQN algorithm.

    Parameters:
    - trial: Optuna trial object
    - n_actions: Number of discrete actions
    - n_envs: Number of environments
    - additional_args: Additional arguments

    Returns:
    dict: Sampled DQN hyperparameters including learning_rate, buffer_size,
    batch_size, tau, gamma, train_freq, target_update_interval, etc.
    """

def sample_qrdqn_params(
    trial: optuna.Trial,
    n_actions: int,
    n_envs: int,
    additional_args: dict
) -> dict[str, Any]:
    """
    Sample hyperparameters for QR-DQN (Quantile Regression DQN).

    Parameters:
    - trial: Optuna trial object
    - n_actions: Number of actions
    - n_envs: Number of environments
    - additional_args: Additional arguments

    Returns:
    dict: Sampled QR-DQN hyperparameters with quantile-specific parameters
    """
```

### TD3 Parameter Sampling

Sampling functions for Twin Delayed Deep Deterministic Policy Gradient.

```python { .api }
def sample_td3_params(
    trial: optuna.Trial,
    n_actions: int,
    n_envs: int,
    additional_args: dict
) -> dict[str, Any]:
    """
    Sample hyperparameters for TD3 algorithm.

    Parameters:
    - trial: Optuna trial object
    - n_actions: Number of continuous actions
    - n_envs: Number of environments
    - additional_args: Additional arguments

    Returns:
    dict: Sampled TD3 hyperparameters including learning_rate, buffer_size,
    batch_size, tau, gamma, train_freq, policy_delay, target_policy_noise, etc.
    """
```

### A2C Parameter Sampling

Sampling functions for Advantage Actor-Critic.

```python { .api }
def sample_a2c_params(
    trial: optuna.Trial,
    n_actions: int,
    n_envs: int,
    additional_args: dict
) -> dict[str, Any]:
    """
    Sample hyperparameters for A2C algorithm.

    Parameters:
    - trial: Optuna trial object
    - n_actions: Number of actions
    - n_envs: Number of parallel environments
    - additional_args: Additional arguments

    Returns:
    dict: Sampled A2C hyperparameters including learning_rate, n_steps,
    gamma, gae_lambda, ent_coef, vf_coef, etc.
    """
```

### TRPO Parameter Sampling

Sampling functions for Trust Region Policy Optimization.

```python { .api }
def sample_trpo_params(
    trial: optuna.Trial,
    n_actions: int,
    n_envs: int,
    additional_args: dict
) -> dict[str, Any]:
    """
    Sample hyperparameters for TRPO algorithm.

    Parameters:
    - trial: Optuna trial object
    - n_actions: Number of actions
    - n_envs: Number of environments
    - additional_args: Additional arguments

    Returns:
    dict: Sampled TRPO hyperparameters including learning_rate, n_steps,
    batch_size, gamma, gae_lambda, cg_max_steps, target_kl, etc.
    """
```

### TQC Parameter Sampling

Sampling functions for Truncated Quantile Critics.

```python { .api }
def sample_tqc_params(
    trial: optuna.Trial,
    n_actions: int,
    n_envs: int,
    additional_args: dict
) -> dict[str, Any]:
    """
    Sample hyperparameters for TQC algorithm.

    Parameters:
    - trial: Optuna trial object
    - n_actions: Number of actions
    - n_envs: Number of environments
    - additional_args: Additional arguments

    Returns:
    dict: Sampled TQC hyperparameters with quantile critic parameters
    """
```

### ARS Parameter Sampling

Sampling functions for Augmented Random Search.

```python { .api }
def sample_ars_params(
    trial: optuna.Trial,
    n_actions: int,
    n_envs: int,
    additional_args: dict
) -> dict[str, Any]:
    """
    Sample hyperparameters for ARS algorithm.

    Parameters:
    - trial: Optuna trial object
    - n_actions: Number of actions
    - n_envs: Number of environments
    - additional_args: Additional arguments

    Returns:
    dict: Sampled ARS hyperparameters including n_delta, n_top, learning_rate,
    delta_std, zero_policy, etc.
    """
```

### HER Parameter Sampling

Sampling functions for Hindsight Experience Replay parameters.

```python { .api }
def sample_her_params(
    trial: optuna.Trial,
    hyperparams: dict[str, Any],
    her_kwargs: dict[str, Any]
) -> dict[str, Any]:
    """
    Sample hyperparameters for HER (Hindsight Experience Replay).

    Parameters:
    - trial: Optuna trial object
    - hyperparams: Base algorithm hyperparameters
    - her_kwargs: HER-specific keyword arguments

    Returns:
    dict: Updated hyperparameters with HER configuration
    """
```
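
Unlike the algorithm samplers, the HER sampler takes an existing hyperparameter dict and augments it. The following sketch illustrates that pattern under stated assumptions: `StubTrial` is a stand-in for `optuna.Trial` so the example runs without Optuna, `sample_her_params_sketch` is a hypothetical name, and the sampled ranges are illustrative. `n_sampled_goal` and `goal_selection_strategy` are arguments of the Stable-Baselines3 HER replay buffer.

```python
import random
from typing import Any


class StubTrial:
    """Minimal stand-in for optuna.Trial (illustration only)."""

    def __init__(self, seed: int = 0) -> None:
        self._rng = random.Random(seed)
        self.params: dict[str, Any] = {}

    def suggest_int(self, name: str, low: int, high: int) -> int:
        value = self._rng.randint(low, high)  # inclusive bounds, like Optuna
        self.params[name] = value
        return value

    def suggest_categorical(self, name: str, choices: list[Any]) -> Any:
        value = self._rng.choice(choices)
        self.params[name] = value
        return value


def sample_her_params_sketch(
    trial: Any, hyperparams: dict[str, Any], her_kwargs: dict[str, Any]
) -> dict[str, Any]:
    """Add sampled HER settings to a copy of the base hyperparameters."""
    her_kwargs = dict(her_kwargs)  # copy so the caller's dict is untouched
    her_kwargs["n_sampled_goal"] = trial.suggest_int("n_sampled_goal", 1, 5)
    her_kwargs["goal_selection_strategy"] = trial.suggest_categorical(
        "goal_selection_strategy", ["final", "episode", "future"]
    )
    updated = dict(hyperparams)
    updated["replay_buffer_kwargs"] = her_kwargs
    return updated


base = {"learning_rate": 1e-3, "batch_size": 256}
result = sample_her_params_sketch(StubTrial(seed=0), base, her_kwargs={})
print(result["replay_buffer_kwargs"])
```

Because the base hyperparameters are passed in, a sketch like this can be layered on top of any off-policy sampler's output.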

## Usage Examples

### Basic Hyperparameter Optimization

```python
import argparse

import optuna
from rl_zoo3.exp_manager import ExperimentManager
from rl_zoo3.hyperparams_opt import convert_onpolicy_params, sample_ppo_params


def objective(trial):
    # Sample hyperparameters
    sampled_params = sample_ppo_params(
        trial=trial,
        n_actions=2,  # CartPole has 2 actions
        n_envs=4,
        additional_args={},
    )

    # Convert parameters
    hyperparams = convert_onpolicy_params(sampled_params)

    # Create experiment manager
    args = argparse.Namespace(
        algo='ppo',
        env='CartPole-v1',
        n_timesteps=10000,
        eval_freq=2000,
        n_eval_episodes=5,
        verbose=0,
    )

    exp_manager = ExperimentManager(
        args=args,
        algo='ppo',
        env_id='CartPole-v1',
        log_folder='./optim_logs',
        hyperparams=hyperparams,
        n_timesteps=10000,
        eval_freq=2000,
    )

    # Set up and train. In current rl_zoo3 releases setup_experiment() returns
    # a (model, saved_hyperparams) tuple; the check also tolerates a bare model.
    results = exp_manager.setup_experiment()
    model = results[0] if isinstance(results, tuple) else results
    exp_manager.learn(model)

    # Return performance metric
    # (In practice, this would be extracted from the evaluation callback)
    return 200.0  # Placeholder reward


# Run optimization
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=20)

print("Best parameters:", study.best_params)
print("Best value:", study.best_value)
```

### Multi-Algorithm Optimization

```python
import optuna
from rl_zoo3.hyperparams_opt import (
    sample_ppo_params, sample_sac_params,
    convert_onpolicy_params, convert_offpolicy_params
)


def multi_algo_objective(trial):
    # Select algorithm
    algo_name = trial.suggest_categorical('algorithm', ['ppo', 'sac'])

    if algo_name == 'ppo':
        sampled_params = sample_ppo_params(trial, n_actions=2, n_envs=4, additional_args={})
        hyperparams = convert_onpolicy_params(sampled_params)
    elif algo_name == 'sac':
        sampled_params = sample_sac_params(trial, n_actions=1, n_envs=1, additional_args={})
        hyperparams = convert_offpolicy_params(sampled_params)

    # Create and train model with selected algorithm and parameters
    # ... (training code similar to above example)

    # Placeholder: replace with the mean evaluation reward of the trained model
    performance_score = 0.0
    return performance_score


# Optimize across algorithms
study = optuna.create_study(direction='maximize')
study.optimize(multi_algo_objective, n_trials=50)
```

### Distributed Optimization

```python
import argparse

import optuna
from rl_zoo3.exp_manager import ExperimentManager
from rl_zoo3.hyperparams_opt import convert_onpolicy_params, sample_ppo_params


def create_distributed_study():
    # Create study with database storage so several processes can share it
    study = optuna.create_study(
        study_name='rl_zoo3_optimization',
        storage='sqlite:///optuna_study.db',
        direction='maximize',
        load_if_exists=True,
    )
    return study


def distributed_objective(trial):
    # Sample parameters for chosen algorithm
    algo = 'ppo'  # Could be parameterized

    if algo == 'ppo':
        sampled_params = sample_ppo_params(trial, n_actions=4, n_envs=8, additional_args={})
        hyperparams = convert_onpolicy_params(sampled_params)
    else:
        raise ValueError(f'No sampler configured for {algo}')

    # Create experiment manager with optimization settings
    args = argparse.Namespace(
        algo=algo,
        env='LunarLander-v2',
        n_timesteps=50000,
        eval_freq=5000,
        n_eval_episodes=10,
        verbose=0,
        seed=trial.suggest_int('seed', 0, 2**32 - 1),
    )

    exp_manager = ExperimentManager(
        args=args,
        algo=algo,
        env_id='LunarLander-v2',
        log_folder=f'./optim_logs/trial_{trial.number}',
        hyperparams=hyperparams,
        n_timesteps=50000,
        eval_freq=5000,
    )

    # Train and evaluate. In current rl_zoo3 releases setup_experiment() returns
    # a (model, saved_hyperparams) tuple; the check also tolerates a bare model.
    results = exp_manager.setup_experiment()
    model = results[0] if isinstance(results, tuple) else results
    exp_manager.learn(model)

    # Extract performance (typically from evaluation callback)
    return trial.suggest_float('mock_performance', -500, 500)  # Placeholder


# Run distributed optimization: start this script in several processes
study = create_distributed_study()
study.optimize(distributed_objective, n_trials=10)  # Each process runs 10 trials
```

### Custom Parameter Sampling

```python
import optuna
from rl_zoo3.hyperparams_opt import convert_onpolicy_params


def sample_custom_ppo_params(trial, n_actions, n_envs, additional_args):
    """
    Custom PPO parameter sampling with different ranges.
    """
    # Learning rate with log-uniform distribution
    learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True)

    # Batch size as power of 2
    batch_size_exp = trial.suggest_int('batch_size_exp', 4, 8)  # 2^4 to 2^8
    batch_size = 2 ** batch_size_exp

    # Number of steps
    n_steps = trial.suggest_categorical('n_steps', [128, 256, 512, 1024, 2048])

    # Ensure batch_size <= n_steps * n_envs
    if batch_size > n_steps * n_envs:
        batch_size = n_steps * n_envs

    # Other hyperparameters
    gamma = trial.suggest_float('gamma', 0.9, 0.9999)
    gae_lambda = trial.suggest_float('gae_lambda', 0.8, 1.0)
    clip_range = trial.suggest_float('clip_range', 0.1, 0.4)
    ent_coef = trial.suggest_float('ent_coef', 1e-8, 1e-1, log=True)

    return {
        'learning_rate': learning_rate,
        'n_steps': n_steps,
        'batch_size': batch_size,
        'gamma': gamma,
        'gae_lambda': gae_lambda,
        'clip_range': clip_range,
        'ent_coef': ent_coef,
        'n_epochs': trial.suggest_int('n_epochs', 3, 10),
        'vf_coef': trial.suggest_float('vf_coef', 0.1, 1.0),
    }


# Use custom sampling in optimization
def custom_objective(trial):
    sampled_params = sample_custom_ppo_params(
        trial, n_actions=4, n_envs=8, additional_args={}
    )
    hyperparams = convert_onpolicy_params(sampled_params)

    # ... rest of training code

    # Placeholder: replace with the mean evaluation reward of the trained model
    performance = 0.0
    return performance
```
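
The `log=True` flag used above tells Optuna to sample in the log domain, so each order of magnitude in the range is equally likely. The sketch below reproduces that idea with the standard library only; `sample_log_uniform` is a hypothetical helper and does not call Optuna.

```python
import math
import random


def sample_log_uniform(rng: random.Random, low: float, high: float) -> float:
    """Draw a value whose logarithm is uniform on [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


rng = random.Random(0)
samples = [sample_log_uniform(rng, 1e-5, 1e-2) for _ in range(10_000)]

# The range 1e-5..1e-2 spans three decades, so about one third of the draws
# should land in the lowest decade (below 1e-4).
share_below_1e_4 = sum(s < 1e-4 for s in samples) / len(samples)
print(f"share below 1e-4: {share_below_1e_4:.2f}")
```

A plain uniform draw over the same range would put about 99% of samples above 1e-4, which is why learning rates and entropy coefficients are usually searched on a log scale.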

### Integration with ExperimentManager

```python
from rl_zoo3.exp_manager import ExperimentManager
import argparse

# ExperimentManager handles optimization automatically
args = argparse.Namespace(
    algo='ppo',
    env='CartPole-v1',
    n_timesteps=20000,
    eval_freq=2000,
    optimize_hyperparameters=True,  # Enable optimization
    n_trials=30,
    n_jobs=2,
    sampler='tpe',
    pruner='median',
    study_name='ppo_cartpole_study',
    storage='sqlite:///ppo_optimization.db'
)

# Automatic hyperparameter optimization
exp_manager = ExperimentManager(
    args=args,
    algo='ppo',
    env_id='CartPole-v1',
    log_folder='./optim_logs',
    optimize_hyperparameters=True,
    n_trials=30,
    n_jobs=2,
    sampler='tpe',
    pruner='median'
)

# This will run the full optimization process
exp_manager.hyperparameters_optimization()
```

## Supported Samplers and Pruners

The optimization system supports the following Optuna samplers and pruners:

**Samplers:**
- `'tpe'`: Tree-structured Parzen Estimator (default, good for most cases)
- `'random'`: Random sampling (baseline)
- `'cmaes'`: CMA-ES (good for continuous parameters)

**Pruners:**
- `'median'`: Median pruner (default; prunes trials whose intermediate results fall below the median of earlier trials)
- `'successive_halving'`: Successive halving (aggressive pruning)
- `'hyperband'`: Hyperband (adaptive resource allocation)
- `'nop'`: No pruning
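
To illustrate what the median pruner decides at each evaluation step, here is a simplified sketch of the rule for a maximized objective. It is an approximation for explanation, not Optuna's implementation: `should_prune_median` is a hypothetical function, and it compares only the current value rather than tracking each trial's best intermediate result.

```python
from statistics import median


def should_prune_median(
    current_value: float,
    previous_values_at_step: list[float],
    n_startup_trials: int = 5,
) -> bool:
    """Return True if a trial looks worse than the median of earlier trials.

    previous_values_at_step holds the intermediate results that earlier
    trials reported at the same evaluation step.
    """
    # Never prune until enough earlier trials exist to form a baseline.
    if len(previous_values_at_step) < n_startup_trials:
        return False
    return current_value < median(previous_values_at_step)


history = [50.0, 60.0, 70.0, 80.0, 90.0]  # earlier trials, same step
print(should_prune_median(10.0, history))  # True: below the median of 70
print(should_prune_median(75.0, history))  # False: above the median
print(should_prune_median(10.0, history[:2]))  # False: too few trials so far
```

The startup threshold matters in RL tuning because early trials are noisy; pruning against a baseline of only one or two runs would discard configurations on little evidence.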