# Experiment Management

The `ExperimentManager` class orchestrates experiments end to end: it handles training workflows, hyperparameter optimization, environment setup, and model coordination. It is the central component that ties together all aspects of an RL training experiment.

## Core Imports

```python
import argparse
from typing import Any, Optional, Union

import torch as th
from rl_zoo3.exp_manager import ExperimentManager
```

## Capabilities

### ExperimentManager Class

The main class for managing RL experiments, from initial setup through training and evaluation. It handles hyperparameter loading, environment creation, model instantiation, and training coordination.

```python { .api }
class ExperimentManager:
    """
    Experiment manager: read the hyperparameters,
    preprocess them, create the environment and the RL model.
    """

    def __init__(
        self,
        args: argparse.Namespace,
        algo: str,
        env_id: str,
        log_folder: str,
        tensorboard_log: str = "",
        n_timesteps: int = 0,
        eval_freq: int = 10000,
        n_eval_episodes: int = 5,
        save_freq: int = -1,
        hyperparams: Optional[dict[str, Any]] = None,
        env_kwargs: Optional[dict[str, Any]] = None,
        eval_env_kwargs: Optional[dict[str, Any]] = None,
        trained_agent: str = "",
        optimize_hyperparameters: bool = False,
        storage: Optional[str] = None,
        study_name: Optional[str] = None,
        n_trials: int = 1,
        max_total_trials: Optional[int] = None,
        n_jobs: int = 1,
        sampler: str = "tpe",
        pruner: str = "median",
        optimization_log_path: Optional[str] = None,
        n_startup_trials: int = 0,
        n_evaluations: int = 1,
        truncate_last_trajectory: bool = False,
        uuid_str: str = "",
        seed: int = 0,
        log_interval: int = 0,
        save_replay_buffer: bool = False,
        verbose: int = 1,
        vec_env_type: str = "dummy",
        n_eval_envs: int = 1,
        no_optim_plots: bool = False,
        device: Union[th.device, str] = "auto",
        config: Optional[str] = None,
        show_progress: bool = False,
        trial_id: Optional[int] = None
    ):
        """
        Initialize ExperimentManager.

        Parameters:
        - args: Command line arguments namespace
        - algo: Algorithm name (must be in ALGOS dict)
        - env_id: Environment identifier
        - log_folder: Directory for saving logs and models
        - tensorboard_log: Tensorboard logging directory
        - n_timesteps: Total training timesteps
        - eval_freq: Frequency of evaluation (in timesteps)
        - n_eval_episodes: Number of episodes for evaluation
        - save_freq: Frequency of model saving (-1 to disable)
        - hyperparams: Override hyperparameters
        - env_kwargs: Environment creation arguments
        - eval_env_kwargs: Evaluation environment arguments
        - trained_agent: Path to pre-trained agent to load
        - optimize_hyperparameters: Whether to run hyperparameter optimization
        - storage: Optuna storage URL for hyperparameter optimization
        - study_name: Optuna study name
        - n_trials: Number of hyperparameter optimization trials
        - max_total_trials: Maximum total trials across all processes
        - n_jobs: Number of parallel jobs for optimization
        - sampler: Optuna sampler ('random' or 'tpe'; other values depend on the rl_zoo3 version)
        - pruner: Optuna pruner ('halving', 'median', 'none')
        - optimization_log_path: Path for optimization logs
        - n_startup_trials: Number of startup trials for pruner
        - n_evaluations: Number of evaluations per trial
        - truncate_last_trajectory: Whether to truncate last trajectory
        - uuid_str: Unique identifier string
        - seed: Random seed
        - log_interval: Logging interval during training
        - save_replay_buffer: Whether to save replay buffer
        - verbose: Verbosity level
        - vec_env_type: Type of vectorized environment ('dummy', 'subproc')
        - n_eval_envs: Number of parallel evaluation environments
        - no_optim_plots: Whether to disable optimization plots
        - device: Device to use ('auto', 'cpu', 'cuda', torch.device)
        - config: Path to configuration file
        - show_progress: Whether to show progress bar
        - trial_id: Optional trial ID for hyperparameter optimization
        """
```
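
Because `args` is an `argparse.Namespace`, the constructor arguments are typically derived from parsed command-line flags. The following is a minimal sketch of that mapping. The parser, its flag names, and the `build_manager_kwargs` helper are hypothetical illustrations and not part of `rl_zoo3`; the defaults simply mirror the signature above.

```python
import argparse
from typing import Any


def make_parser() -> argparse.ArgumentParser:
    # Hypothetical parser: flag names are illustrative only.
    parser = argparse.ArgumentParser()
    parser.add_argument("--algo", type=str, default="ppo")
    parser.add_argument("--env", type=str, default="CartPole-v1")
    parser.add_argument("--log-folder", type=str, default="logs")
    parser.add_argument("--n-timesteps", type=int, default=0)
    parser.add_argument("--eval-freq", type=int, default=10000)
    parser.add_argument("--seed", type=int, default=0)
    return parser


def build_manager_kwargs(args: argparse.Namespace) -> dict[str, Any]:
    # Map parsed flags onto the keyword names of ExperimentManager.__init__.
    return {
        "args": args,
        "algo": args.algo,
        "env_id": args.env,
        "log_folder": args.log_folder,
        "n_timesteps": args.n_timesteps,
        "eval_freq": args.eval_freq,
        "seed": args.seed,
    }


# Parse an explicit argument list instead of sys.argv so the sketch is reproducible.
cli_args = make_parser().parse_args(["--algo", "sac", "--env", "Pendulum-v1", "--seed", "7"])
manager_kwargs = build_manager_kwargs(cli_args)
```

The resulting dictionary could then be passed on as `ExperimentManager(**manager_kwargs)`.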

### Experiment Setup

Core methods for setting up and configuring experiments before training begins.

```python { .api }
def setup_experiment(self) -> Optional[tuple[BaseAlgorithm, dict[str, Any]]]:
    """
    Set up the experiment: load hyperparameters, create environments, and instantiate the model.

    Returns:
        Optional[tuple[BaseAlgorithm, dict[str, Any]]]: The configured RL model and the saved
        hyperparameters, or None when hyperparameter optimization is enabled
    """

def create_log_folder(self) -> None:
    """
    Create log folder and set up logging directories.
    """

def create_callbacks(self) -> list[BaseCallback]:
    """
    Create training callbacks based on configuration.

    Returns:
        list[BaseCallback]: List of configured callbacks
    """
```
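
`eval_freq` and `save_freq` are expressed in timesteps, whereas Stable-Baselines3 callbacks count calls to the vectorized `env.step()`, each of which advances every parallel environment by one step. A callback frequency therefore usually needs to be divided by the number of environments. The helper below sketches that conversion; `per_env_frequency` is a hypothetical name used for illustration, not an `rl_zoo3` function.

```python
def per_env_frequency(freq_timesteps: int, n_envs: int) -> int:
    """Convert a frequency in total timesteps into a number of vectorized steps."""
    if n_envs < 1:
        raise ValueError("n_envs must be at least 1")
    if freq_timesteps <= 0:
        # Non-positive values mean "disabled" and are passed through unchanged.
        return freq_timesteps
    # Never return 0 for an enabled callback, otherwise it would never fire.
    return max(freq_timesteps // n_envs, 1)


eval_calls = per_env_frequency(10000, n_envs=4)
save_calls = per_env_frequency(-1, n_envs=4)
```

With four environments, an evaluation every 10000 timesteps corresponds to one evaluation every 2500 vectorized steps.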

### Environment Management

Methods for creating and managing training and evaluation environments with proper configuration and wrappers.

```python { .api }
def create_envs(self, n_envs: int, eval_env: bool = False) -> VecEnv:
    """
    Create vectorized environments for training or evaluation.

    Parameters:
    - n_envs: Number of parallel environments
    - eval_env: Whether this is for evaluation

    Returns:
        VecEnv: Configured vectorized environment
    """

def get_env_kwargs(self) -> dict[str, Any]:
    """
    Get environment creation keyword arguments.

    Returns:
        dict: Environment kwargs
    """
```

### Model Creation and Loading

Methods for creating new models or loading pre-trained models with proper configuration.

```python { .api }
def create_model(self) -> BaseAlgorithm:
    """
    Create a new RL model with loaded hyperparameters.

    Returns:
        BaseAlgorithm: Configured RL model
    """

def load_trained_model(self) -> BaseAlgorithm:
    """
    Load a pre-trained model.

    Returns:
        BaseAlgorithm: Loaded RL model
    """
```

### Training and Learning

Methods for executing the training process with proper monitoring and checkpointing.

```python { .api }
def learn(self, model: BaseAlgorithm) -> None:
    """
    Train the model with configured parameters and callbacks.

    Parameters:
    - model: RL model to train
    """

def save_trained_model(self, model: BaseAlgorithm) -> None:
    """
    Save the trained model and associated files.

    Parameters:
    - model: Trained RL model to save
    """
```

### Hyperparameter Optimization

Methods for running hyperparameter optimization with Optuna, including optimization distributed across multiple processes.

```python { .api }
def hyperparameters_optimization(self) -> None:
    """
    Run hyperparameter optimization using Optuna.
    Supports distributed optimization across multiple processes.
    """

def objective(self, trial: optuna.Trial) -> float:
    """
    Optuna objective function for hyperparameter optimization.

    Parameters:
    - trial: Optuna trial object

    Returns:
        float: Trial objective value (reward)
    """
```

### Configuration and Setup Methods

Methods for reading, loading, and preprocessing hyperparameters and configuration files.

```python { .api }
def read_hyperparameters(self) -> tuple[dict[str, Any], dict[str, Any]]:
    """
    Read hyperparameters from YAML configuration files.

    Returns:
        tuple[dict[str, Any], dict[str, Any]]: (hyperparams, saved_hyperparams)
    """

def load_trial(self, trial_id: int) -> None:
    """
    Load a specific Optuna trial configuration.

    Parameters:
    - trial_id: ID of the trial to load
    """

def _save_config(self, saved_hyperparams: dict[str, Any]) -> None:
    """
    Save configuration and hyperparameters to log directory.

    Parameters:
    - saved_hyperparams: Hyperparameters to save
    """
```
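
The lookup-and-override step behind `read_hyperparameters` can be pictured as follows. This is a simplified sketch that uses an in-memory dictionary in place of a YAML file. The `resolve_hyperparams` helper and the sample values are hypothetical, and it assumes that override values take precedence over file values and that the saved copy is sorted by key.

```python
from collections import OrderedDict
from typing import Any, Optional


def resolve_hyperparams(
    all_hyperparams: dict[str, dict[str, Any]],
    env_id: str,
    overrides: Optional[dict[str, Any]] = None,
) -> tuple[dict[str, Any], "OrderedDict[str, Any]"]:
    """Pick the entry for env_id, apply overrides, and build a sorted copy for saving."""
    if env_id not in all_hyperparams:
        raise ValueError(f"Hyperparameters not found for {env_id}")
    # Copy so that the loaded file contents are never mutated.
    hyperparams = dict(all_hyperparams[env_id])
    if overrides is not None:
        hyperparams.update(overrides)
    # Sorting keys keeps the saved configuration stable across runs.
    saved_hyperparams = OrderedDict(sorted(hyperparams.items()))
    return hyperparams, saved_hyperparams


# Stand-in for the parsed contents of a per-algorithm YAML file.
file_contents = {
    "CartPole-v1": {"policy": "MlpPolicy", "n_timesteps": 100000, "learning_rate": 0.001},
}
hyperparams, saved_hyperparams = resolve_hyperparams(
    file_contents, "CartPole-v1", overrides={"learning_rate": 0.0003}
)
```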

### Preprocessing Methods

Internal methods for preprocessing hyperparameters and configuration before training.

```python { .api }
@staticmethod
def _preprocess_schedules(hyperparams: dict[str, Any]) -> dict[str, Any]:
    """
    Preprocess learning rate and other parameter schedules.

    Parameters:
    - hyperparams: Raw hyperparameters

    Returns:
        dict[str, Any]: Processed hyperparameters with schedule objects
    """

def _preprocess_normalization(self, hyperparams: dict[str, Any]) -> dict[str, Any]:
    """
    Preprocess VecNormalize parameters.

    Parameters:
    - hyperparams: Raw hyperparameters

    Returns:
        dict[str, Any]: Processed hyperparameters with normalization config
    """

def _preprocess_hyperparams(self, hyperparams: dict[str, Any]) -> dict[str, Any]:
    """
    Preprocess all hyperparameters before model creation.

    Parameters:
    - hyperparams: Raw hyperparameters

    Returns:
        dict[str, Any]: Fully processed hyperparameters
    """

def _preprocess_action_noise(self, hyperparams: dict[str, Any]) -> dict[str, Any]:
    """
    Preprocess action noise parameters for algorithms that support it.

    Parameters:
    - hyperparams: Raw hyperparameters

    Returns:
        dict[str, Any]: Processed hyperparameters with action noise objects
    """
```
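
To make the schedule preprocessing concrete, the sketch below turns string values such as `lin_0.001` into callables that scale with the remaining training progress. It assumes the `lin_<value>` string convention and a hard-coded list of schedule keys; `preprocess_schedules` and `linear_schedule` here are standalone illustrations, not the library's implementation.

```python
from typing import Any, Callable


def linear_schedule(initial_value: float) -> Callable[[float], float]:
    """Return a schedule that decays linearly from initial_value to 0."""

    def schedule(progress_remaining: float) -> float:
        # progress_remaining goes from 1.0 (start of training) to 0.0 (end).
        return progress_remaining * initial_value

    return schedule


def preprocess_schedules(hyperparams: dict[str, Any]) -> dict[str, Any]:
    """Replace 'lin_<value>' strings with linear schedule callables."""
    processed = dict(hyperparams)  # leave the input untouched
    for key in ("learning_rate", "clip_range"):  # assumed set of schedule keys
        value = processed.get(key)
        if isinstance(value, str) and value.startswith("lin_"):
            processed[key] = linear_schedule(float(value[len("lin_"):]))
    return processed


processed = preprocess_schedules({"learning_rate": "lin_0.001", "gamma": 0.99})
halfway_lr = processed["learning_rate"](0.5)
```

Plain numeric values such as `gamma` pass through unchanged, so constant and scheduled hyperparameters can share one configuration entry.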

### Environment and Model Management

Methods for environment creation, model loading, and related utilities.

```python { .api }
def _maybe_normalize(self, env: VecEnv, eval_env: bool) -> VecEnv:
    """
    Apply VecNormalize wrapper if specified in hyperparameters.

    Parameters:
    - env: Vector environment
    - eval_env: Whether this is an evaluation environment

    Returns:
        VecEnv: Potentially normalized environment
    """

def _load_pretrained_agent(self, hyperparams: dict[str, Any], env: VecEnv) -> BaseAlgorithm:
    """
    Load a pretrained agent for transfer learning or continued training.

    Parameters:
    - hyperparams: Model hyperparameters
    - env: Training environment

    Returns:
        BaseAlgorithm: Loaded pretrained model
    """
```

### Optuna Integration

Methods for creating Optuna samplers and pruners for hyperparameter optimization.

```python { .api }
def _create_sampler(self, sampler_method: str) -> BaseSampler:
    """
    Create Optuna sampler for hyperparameter optimization.

    Parameters:
    - sampler_method: Sampler type ("random" or "tpe"; other values depend on the rl_zoo3 version)

    Returns:
        BaseSampler: Configured Optuna sampler
    """

def _create_pruner(self, pruner_method: str) -> BasePruner:
    """
    Create Optuna pruner for early stopping of unpromising trials.

    Parameters:
    - pruner_method: Pruner type ("halving", "median", "none")

    Returns:
        BasePruner: Configured Optuna pruner
    """
```
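
The idea behind median pruning can be shown without Optuna: a trial is stopped early when its intermediate result is worse than the median of what earlier trials achieved at the same evaluation step. The sketch below is a simplified standalone illustration for a maximized objective such as mean reward. `should_prune` is a hypothetical helper, and its handling of `n_startup_trials` is an approximation of the real pruner's behavior.

```python
from statistics import median


def should_prune(
    current_value: float,
    previous_values_at_step: list[float],
    n_startup_trials: int = 0,
) -> bool:
    """Median rule for a maximized objective (higher is better)."""
    # Without enough earlier trials there is no reliable baseline, so never prune.
    if len(previous_values_at_step) < max(n_startup_trials, 1):
        return False
    return current_value < median(previous_values_at_step)


# Mean rewards that three earlier trials reported at the same evaluation step.
earlier_rewards = [100.0, 120.0, 80.0]
prune_weak_trial = should_prune(50.0, earlier_rewards)
prune_strong_trial = should_prune(110.0, earlier_rewards)
```

A larger `n_startup_trials` delays pruning until more trials have reported, which trades compute for a more reliable baseline.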

### Environment Detection Utilities

Static methods for detecting specific environment types so that the appropriate configuration can be applied.

```python { .api }
@staticmethod
def entry_point(env_id: str) -> str:
    """
    Get the entry point for a given environment ID.

    Parameters:
    - env_id: Environment identifier

    Returns:
        str: Entry point string
    """

@staticmethod
def is_atari(env_id: str) -> bool:
    """
    Check if environment is an Atari environment.

    Parameters:
    - env_id: Environment identifier

    Returns:
        bool: True if Atari environment
    """

@staticmethod
def is_minigrid(env_id: str) -> bool:
    """
    Check if environment is a MiniGrid environment.

    Parameters:
    - env_id: Environment identifier

    Returns:
        bool: True if MiniGrid environment
    """

@staticmethod
def is_bullet(env_id: str) -> bool:
    """
    Check if environment is a PyBullet environment.

    Parameters:
    - env_id: Environment identifier

    Returns:
        bool: True if PyBullet environment
    """

@staticmethod
def is_robotics_env(env_id: str) -> bool:
    """
    Check if environment is a robotics environment.

    Parameters:
    - env_id: Environment identifier

    Returns:
        bool: True if robotics environment
    """

@staticmethod
def is_panda_gym(env_id: str) -> bool:
    """
    Check if environment is a Panda Gym environment.

    Parameters:
    - env_id: Environment identifier

    Returns:
        bool: True if Panda Gym environment
    """
```
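
One plausible way to implement such checks is to inspect the entry-point string registered for an environment ID and look for a package-specific substring. The sketch below illustrates that approach with a hand-written registry. `SAMPLE_REGISTRY`, its entry-point strings, and the substrings being matched are assumptions made for illustration and are not read from an installed package.

```python
# Hypothetical registry mapping environment IDs to entry-point strings.
SAMPLE_REGISTRY = {
    "ALE/Breakout-v5": "ale_py.env:AtariEnv",
    "MiniGrid-Empty-5x5-v0": "minigrid.envs:EmptyEnv",
    "PandaReach-v3": "panda_gym.envs:PandaReachEnv",
    "CartPole-v1": "gymnasium.envs.classic_control.cartpole:CartPoleEnv",
}


def entry_point(env_id: str) -> str:
    """Look up the entry-point string for an environment ID."""
    return str(SAMPLE_REGISTRY[env_id])


def is_atari(env_id: str) -> bool:
    # Assumed marker: Atari environments share a common class name.
    return "AtariEnv" in entry_point(env_id)


def is_minigrid(env_id: str) -> bool:
    # Assumed marker: the package name appears in the module path.
    return "minigrid" in entry_point(env_id)


def is_panda_gym(env_id: str) -> bool:
    return "panda_gym.envs" in entry_point(env_id)


atari_ids = [env_id for env_id in SAMPLE_REGISTRY if is_atari(env_id)]
```

Matching on the entry point rather than on the environment ID keeps the check independent of naming and version suffixes.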

## Usage Examples

### Basic Training Setup

```python
import argparse
from rl_zoo3.exp_manager import ExperimentManager

# Create arguments (typically from command line)
args = argparse.Namespace(
    algo='ppo',
    env='CartPole-v1',
    n_timesteps=10000,
    eval_freq=1000,
    n_eval_episodes=5,
    save_freq=-1,
    verbose=1,
    seed=42
)

# Create experiment manager
exp_manager = ExperimentManager(
    args=args,
    algo='ppo',
    env_id='CartPole-v1',
    log_folder='./logs',
    n_timesteps=10000,
    eval_freq=1000,
    seed=42
)

# Set up and train. setup_experiment() returns (model, saved_hyperparams),
# or None when hyperparameter optimization is enabled.
results = exp_manager.setup_experiment()
if results is not None:
    model, saved_hyperparams = results
    exp_manager.learn(model)
    exp_manager.save_trained_model(model)
```

### Advanced Training with Custom Configuration

```python
import argparse
from rl_zoo3.exp_manager import ExperimentManager

# Advanced configuration
args = argparse.Namespace(
    algo='sac',
    env='Pendulum-v1',
    n_timesteps=50000,
    eval_freq=5000,
    n_eval_episodes=10,
    save_freq=10000,
    verbose=1,
    seed=123,
    tensorboard_log='./tb_logs',
    vec_env_type='subproc',
    n_envs=4
)

# Custom hyperparameters. The constructor has no n_envs parameter, so the
# number of parallel training environments is set here instead.
custom_hyperparams = {
    'n_envs': 4,
    'learning_rate': 0.0003,
    'buffer_size': 50000,
    'batch_size': 64,
    'tau': 0.02,
    'gamma': 0.98
}

# Custom environment kwargs
env_kwargs = {
    'render_mode': None,
    'max_episode_steps': 200
}

# Create experiment manager with custom settings
exp_manager = ExperimentManager(
    args=args,
    algo='sac',
    env_id='Pendulum-v1',
    log_folder='./logs',
    tensorboard_log='./tb_logs',
    n_timesteps=50000,
    eval_freq=5000,
    n_eval_episodes=10,
    save_freq=10000,
    hyperparams=custom_hyperparams,
    env_kwargs=env_kwargs,
    vec_env_type='subproc',
    seed=123,
    show_progress=True
)

# Set up and train. setup_experiment() returns (model, saved_hyperparams),
# or None when hyperparameter optimization is enabled.
results = exp_manager.setup_experiment()
if results is not None:
    model, saved_hyperparams = results
    exp_manager.learn(model)
    exp_manager.save_trained_model(model)
```

### Hyperparameter Optimization

```python
import argparse
from rl_zoo3.exp_manager import ExperimentManager

# Setup for hyperparameter optimization
args = argparse.Namespace(
    algo='ppo',
    env='CartPole-v1',
    n_timesteps=10000,
    eval_freq=2000,
    n_eval_episodes=5,
    verbose=0,  # Reduce verbosity for optimization
    seed=42
)

# Create experiment manager for optimization
exp_manager = ExperimentManager(
    args=args,
    algo='ppo',
    env_id='CartPole-v1',
    log_folder='./optim_logs',
    n_timesteps=10000,
    eval_freq=2000,
    optimize_hyperparameters=True,
    n_trials=50,
    n_jobs=2,
    sampler='tpe',
    pruner='median',
    study_name='ppo_cartpole_optimization',
    seed=42
)

# Load the hyperparameters and prepare the log folder first; with
# optimize_hyperparameters=True no model is returned.
exp_manager.setup_experiment()

# Run hyperparameter optimization
exp_manager.hyperparameters_optimization()
```

### Loading and Continuing Training

```python
import argparse
from rl_zoo3.exp_manager import ExperimentManager

# Setup for loading pre-trained model
args = argparse.Namespace(
    algo='ppo',
    env='CartPole-v1',
    n_timesteps=20000,  # Additional training steps
    eval_freq=1000,
    verbose=1,
    seed=42
)

# Create experiment manager with trained agent
exp_manager = ExperimentManager(
    args=args,
    algo='ppo',
    env_id='CartPole-v1',
    log_folder='./logs',
    trained_agent='./logs/ppo/CartPole-v1_1/best_model.zip',
    n_timesteps=20000,
    eval_freq=1000,
    seed=42
)

# Load the trained agent and continue training. setup_experiment() returns
# (model, saved_hyperparams) once the agent has been loaded.
results = exp_manager.setup_experiment()
if results is not None:
    model, saved_hyperparams = results
    exp_manager.learn(model)  # Continue training for additional timesteps
    exp_manager.save_trained_model(model)
```