# Experiment Management

The `ExperimentManager` class orchestrates experiments end to end: it handles training workflows, hyperparameter optimization, environment setup, and model coordination. It is the central component that ties together all aspects of an RL training experiment.

## Core Imports

```python
import argparse
from typing import Any, Optional, Union

import torch as th  # needed for the `th.device` annotation used below

from rl_zoo3.exp_manager import ExperimentManager
```

## Capabilities

### ExperimentManager Class

The main class for managing RL experiments, from initial setup through training and evaluation. Handles hyperparameter loading, environment creation, model instantiation, and training coordination.

```python { .api }
class ExperimentManager:
    """
    Experiment manager: read the hyperparameters,
    preprocess them, create the environment and the RL model.
    """

    def __init__(
        self,
        args: argparse.Namespace,
        algo: str,
        env_id: str,
        log_folder: str,
        tensorboard_log: str = "",
        n_timesteps: int = 0,
        eval_freq: int = 10000,
        n_eval_episodes: int = 5,
        save_freq: int = -1,
        hyperparams: Optional[dict[str, Any]] = None,
        env_kwargs: Optional[dict[str, Any]] = None,
        eval_env_kwargs: Optional[dict[str, Any]] = None,
        trained_agent: str = "",
        optimize_hyperparameters: bool = False,
        storage: Optional[str] = None,
        study_name: Optional[str] = None,
        n_trials: int = 1,
        max_total_trials: Optional[int] = None,
        n_jobs: int = 1,
        sampler: str = "tpe",
        pruner: str = "median",
        optimization_log_path: Optional[str] = None,
        n_startup_trials: int = 0,
        n_evaluations: int = 1,
        truncate_last_trajectory: bool = False,
        uuid_str: str = "",
        seed: int = 0,
        log_interval: int = 0,
        save_replay_buffer: bool = False,
        verbose: int = 1,
        vec_env_type: str = "dummy",
        n_eval_envs: int = 1,
        no_optim_plots: bool = False,
        device: Union[th.device, str] = "auto",
        config: Optional[str] = None,
        show_progress: bool = False,
        trial_id: Optional[int] = None
    ):
        """
        Initialize ExperimentManager.

        Parameters:
        - args: Command line arguments namespace
        - algo: Algorithm name (must be in ALGOS dict)
        - env_id: Environment identifier
        - log_folder: Directory for saving logs and models
        - tensorboard_log: Tensorboard logging directory
        - n_timesteps: Total training timesteps
        - eval_freq: Frequency of evaluation (in timesteps)
        - n_eval_episodes: Number of episodes for evaluation
        - save_freq: Frequency of model saving (-1 to disable)
        - hyperparams: Override hyperparameters
        - env_kwargs: Environment creation arguments
        - eval_env_kwargs: Evaluation environment arguments
        - trained_agent: Path to pre-trained agent to load
        - optimize_hyperparameters: Whether to run hyperparameter optimization
        - storage: Optuna storage URL for hyperparameter optimization
        - study_name: Optuna study name
        - n_trials: Number of hyperparameter optimization trials
        - max_total_trials: Maximum total trials across all processes
        - n_jobs: Number of parallel jobs for optimization
        - sampler: Optuna sampler ('tpe', 'random', 'cmaes')
        - pruner: Optuna pruner ('median', 'successive_halving', 'hyperband')
        - optimization_log_path: Path for optimization logs
        - n_startup_trials: Number of startup trials for pruner
        - n_evaluations: Number of evaluations per trial
        - truncate_last_trajectory: Whether to truncate last trajectory
        - uuid_str: Unique identifier string
        - seed: Random seed
        - log_interval: Logging interval during training
        - save_replay_buffer: Whether to save replay buffer
        - verbose: Verbosity level
        - vec_env_type: Type of vectorized environment ('dummy', 'subproc')
        - n_eval_envs: Number of parallel evaluation environments
        - no_optim_plots: Whether to disable optimization plots
        - device: Device to use ('auto', 'cpu', 'cuda', torch.device)
        - config: Path to configuration file
        - show_progress: Whether to show progress bar
        - trial_id: Optional trial ID for hyperparameter optimization
        """
```
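
The `args` parameter is an `argparse.Namespace`, typically produced by a command-line parser. As a minimal sketch of how such a namespace might be built and mapped onto the constructor's keyword arguments, consider the following. The flag names and the `build_manager_kwargs` helper are hypothetical and not part of `rl_zoo3`; only the keyword names (`args`, `algo`, `env_id`, `log_folder`, `n_timesteps`, `seed`) come from the signature above.

```python
import argparse
from typing import Any, Optional, Sequence


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse a small, hypothetical subset of training flags."""
    parser = argparse.ArgumentParser(description="Hypothetical training CLI")
    parser.add_argument("--algo", type=str, default="ppo")
    parser.add_argument("--env", type=str, default="CartPole-v1")
    parser.add_argument("--log-folder", type=str, default="logs")
    parser.add_argument("--n-timesteps", type=int, default=0)
    parser.add_argument("--seed", type=int, default=0)
    return parser.parse_args(argv)


def build_manager_kwargs(args: argparse.Namespace) -> dict[str, Any]:
    """Map parsed flags onto keyword names from the constructor signature."""
    return {
        "args": args,  # the full namespace is passed through as well
        "algo": args.algo,
        "env_id": args.env,
        "log_folder": args.log_folder,
        "n_timesteps": args.n_timesteps,
        "seed": args.seed,
    }


# Parse an explicit argument list so the example does not depend on sys.argv.
args = parse_args(["--algo", "sac", "--env", "Pendulum-v1", "--seed", "7"])
kwargs = build_manager_kwargs(args)
```

With `rl_zoo3` installed, the resulting dictionary could be unpacked as `ExperimentManager(**kwargs)`; every parameter not listed keeps its default.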

### Experiment Setup

Core methods for setting up and configuring experiments before training begins.

```python { .api }
def setup_experiment(self) -> BaseAlgorithm:
    """
    Set up the experiment: load hyperparameters, create environments, and instantiate the model.

    Returns:
    BaseAlgorithm: Configured RL model ready for training
    """

def create_log_folder(self) -> None:
    """
    Create log folder and set up logging directories.
    """

def create_callbacks(self) -> list[BaseCallback]:
    """
    Create training callbacks based on configuration.

    Returns:
    list[BaseCallback]: List of configured callbacks
    """
```

### Environment Management

Methods for creating and managing training and evaluation environments with proper configuration and wrappers.

```python { .api }
def create_envs(self, n_envs: int, eval_env: bool = False) -> VecEnv:
    """
    Create vectorized environments for training or evaluation.

    Parameters:
    - n_envs: Number of parallel environments
    - eval_env: Whether this is for evaluation

    Returns:
    VecEnv: Configured vectorized environment
    """

def get_env_kwargs(self) -> dict[str, Any]:
    """
    Get environment creation keyword arguments.

    Returns:
    dict: Environment kwargs
    """
```

### Model Creation and Loading

Methods for creating new models or loading pre-trained models with proper configuration.

```python { .api }
def create_model(self) -> BaseAlgorithm:
    """
    Create a new RL model with loaded hyperparameters.

    Returns:
    BaseAlgorithm: Configured RL model
    """

def load_trained_model(self) -> BaseAlgorithm:
    """
    Load a pre-trained model.

    Returns:
    BaseAlgorithm: Loaded RL model
    """
```

### Training and Learning

Methods for executing the training process with proper monitoring and checkpointing.

```python { .api }
def learn(self, model: BaseAlgorithm) -> None:
    """
    Train the model with configured parameters and callbacks.

    Parameters:
    - model: RL model to train
    """

def save_trained_model(self, model: BaseAlgorithm) -> None:
    """
    Save the trained model and associated files.

    Parameters:
    - model: Trained RL model to save
    """
```

### Hyperparameter Optimization

Methods for running hyperparameter optimization with Optuna, including distributed optimization across multiple processes.

```python { .api }
def hyperparameters_optimization(self) -> None:
    """
    Run hyperparameter optimization using Optuna.
    Supports distributed optimization across multiple processes.
    """

def objective(self, trial: optuna.Trial) -> float:
    """
    Optuna objective function for hyperparameter optimization.

    Parameters:
    - trial: Optuna trial object

    Returns:
    float: Trial objective value (reward)
    """
```
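
The contract of `objective` is simple: it receives a trial, asks it for hyperparameter values, and returns a single float to maximize. The sketch below illustrates that contract with plain random search and the standard library only. `StandInTrial` is a hypothetical stand-in, not `optuna.Trial`, and the toy reward replaces what would really be an evaluation of a trained agent.

```python
import random
from typing import Callable


class StandInTrial:
    """Minimal stand-in for a trial object (not the real optuna.Trial)."""

    def __init__(self, rng: random.Random) -> None:
        self._rng = rng
        self.params: dict[str, float] = {}

    def suggest_float(self, name: str, low: float, high: float) -> float:
        # Sample uniformly and remember the value under its name.
        value = self._rng.uniform(low, high)
        self.params[name] = value
        return value


def objective(trial: StandInTrial) -> float:
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-1)
    # Toy "reward": highest when the learning rate is close to 1e-3.
    return -abs(learning_rate - 1e-3)


def random_search(
    objective_fn: Callable[[StandInTrial], float], n_trials: int, seed: int = 0
):
    """Run n_trials independent trials and keep the best (highest) value."""
    rng = random.Random(seed)
    history = []
    for _ in range(n_trials):
        trial = StandInTrial(rng)
        value = objective_fn(trial)
        history.append((value, dict(trial.params)))
    best_value, best_params = max(history, key=lambda item: item[0])
    return best_value, best_params, history


best_value, best_params, history = random_search(objective, n_trials=20, seed=42)
```

A real sampler such as TPE proposes values based on earlier trials instead of sampling uniformly, and a pruner can stop a trial early; the shape of the objective stays the same.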

### Configuration and Setup Methods

Methods for reading, loading, and preprocessing hyperparameters and configuration files.

```python { .api }
def read_hyperparameters(self) -> tuple[dict[str, Any], dict[str, Any]]:
    """
    Read hyperparameters from YAML configuration files.

    Returns:
    tuple[dict[str, Any], dict[str, Any]]: (hyperparams, saved_hyperparams)
    """

def load_trial(self, trial_id: int) -> None:
    """
    Load a specific Optuna trial configuration.

    Parameters:
    - trial_id: ID of the trial to load
    """

def _save_config(self, saved_hyperparams: dict[str, Any]) -> None:
    """
    Save configuration and hyperparameters to log directory.

    Parameters:
    - saved_hyperparams: Hyperparameters to save
    """
```
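
`read_hyperparameters` returns two dictionaries. One plausible reading, offered as an assumption rather than a description of the library's internals, is that the defaults for the environment are looked up in the parsed configuration, the `hyperparams` overrides passed to the constructor are applied on top, and a sorted copy is kept for saving. The sketch below uses a plain dictionary in place of a parsed YAML file; `CONFIG`, its values, and `read_hyperparameters_sketch` are hypothetical.

```python
from collections import OrderedDict
from typing import Any, Optional

# Stand-in for the contents of a parsed per-algorithm YAML file (hypothetical values).
CONFIG: dict[str, dict[str, Any]] = {
    "CartPole-v1": {
        "n_envs": 8,
        "learning_rate": "lin_0.001",
        "policy": "MlpPolicy",
    },
}


def read_hyperparameters_sketch(
    config: dict[str, dict[str, Any]],
    env_id: str,
    overrides: Optional[dict[str, Any]] = None,
) -> tuple[dict[str, Any], "OrderedDict[str, Any]"]:
    """Return (hyperparams, saved_hyperparams) for one environment."""
    if env_id not in config:
        raise ValueError(f"Hyperparameters not found for {env_id}")
    # Copy so that the configuration itself is never mutated.
    hyperparams = dict(config[env_id])
    if overrides is not None:
        hyperparams.update(overrides)  # explicit overrides win over defaults
    # Sorted copy, convenient for writing a stable file to the log folder.
    saved_hyperparams = OrderedDict(sorted(hyperparams.items()))
    return hyperparams, saved_hyperparams


hyperparams, saved_hyperparams = read_hyperparameters_sketch(
    CONFIG, "CartPole-v1", overrides={"learning_rate": 0.0003}
)
```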

### Preprocessing Methods

Internal methods for preprocessing hyperparameters and configuration before training.

```python { .api }
@staticmethod
def _preprocess_schedules(hyperparams: dict[str, Any]) -> dict[str, Any]:
    """
    Preprocess learning rate and other parameter schedules.

    Parameters:
    - hyperparams: Raw hyperparameters

    Returns:
    dict[str, Any]: Processed hyperparameters with schedule objects
    """

def _preprocess_normalization(self, hyperparams: dict[str, Any]) -> dict[str, Any]:
    """
    Preprocess VecNormalize parameters.

    Parameters:
    - hyperparams: Raw hyperparameters

    Returns:
    dict[str, Any]: Processed hyperparameters with normalization config
    """

def _preprocess_hyperparams(self, hyperparams: dict[str, Any]) -> dict[str, Any]:
    """
    Preprocess all hyperparameters before model creation.

    Parameters:
    - hyperparams: Raw hyperparameters

    Returns:
    dict[str, Any]: Fully processed hyperparameters
    """

def _preprocess_action_noise(self, hyperparams: dict[str, Any]) -> dict[str, Any]:
    """
    Preprocess action noise parameters for algorithms that support it.

    Parameters:
    - hyperparams: Raw hyperparameters

    Returns:
    dict[str, Any]: Processed hyperparameters with action noise objects
    """
```
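
To make the idea of schedule preprocessing concrete, here is a minimal sketch that turns string-encoded schedules into callables. It assumes a `lin_<value>` string convention (for example `lin_0.001`) and a schedule that takes the remaining training progress, from 1.0 down to 0.0. Both are assumptions for illustration, and `preprocess_schedules_sketch` and `SCHEDULE_KEYS` are hypothetical names, not the library's implementation.

```python
from typing import Any, Callable

Schedule = Callable[[float], float]

# Keys that may hold a schedule in this sketch (assumed, not exhaustive).
SCHEDULE_KEYS = ("learning_rate", "clip_range")


def linear_schedule(initial_value: float) -> Schedule:
    """Decay linearly from initial_value to 0 as progress_remaining goes 1 -> 0."""

    def schedule(progress_remaining: float) -> float:
        return progress_remaining * initial_value

    return schedule


def preprocess_schedules_sketch(hyperparams: dict[str, Any]) -> dict[str, Any]:
    """Replace 'lin_<value>' strings with linear schedules; keep other values."""
    processed = dict(hyperparams)  # leave the caller's dictionary untouched
    for key in SCHEDULE_KEYS:
        value = processed.get(key)
        if isinstance(value, str):
            kind, initial_value = value.split("_", 1)
            if kind != "lin":
                raise ValueError(f"Unknown schedule type for {key}: {value}")
            processed[key] = linear_schedule(float(initial_value))
    return processed


processed = preprocess_schedules_sketch(
    {"learning_rate": "lin_0.001", "clip_range": 0.2, "gamma": 0.99}
)
```

Here `processed["learning_rate"]` is a callable, while the numeric `clip_range` and the unrelated `gamma` pass through unchanged.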

### Environment and Model Management

Methods for environment creation, model loading, and related utilities.

```python { .api }
def _maybe_normalize(self, env: VecEnv, eval_env: bool) -> VecEnv:
    """
    Apply VecNormalize wrapper if specified in hyperparameters.

    Parameters:
    - env: Vector environment
    - eval_env: Whether this is an evaluation environment

    Returns:
    VecEnv: Potentially normalized environment
    """

def _load_pretrained_agent(self, hyperparams: dict[str, Any], env: VecEnv) -> BaseAlgorithm:
    """
    Load a pretrained agent for transfer learning or continued training.

    Parameters:
    - hyperparams: Model hyperparameters
    - env: Training environment

    Returns:
    BaseAlgorithm: Loaded pretrained model
    """
```

### Optuna Integration

Methods for creating Optuna samplers and pruners for hyperparameter optimization.

```python { .api }
def _create_sampler(self, sampler_method: str) -> BaseSampler:
    """
    Create Optuna sampler for hyperparameter optimization.

    Parameters:
    - sampler_method: Sampler type ("tpe", "random", "cmaes")

    Returns:
    BaseSampler: Configured Optuna sampler
    """

def _create_pruner(self, pruner_method: str) -> BasePruner:
    """
    Create Optuna pruner for early stopping of unpromising trials.

    Parameters:
    - pruner_method: Pruner type ("median", "successive_halving", "nop")

    Returns:
    BasePruner: Configured Optuna pruner
    """
```

### Environment Detection Utilities

Static methods for detecting specific environment types and applying appropriate configurations.

```python { .api }
@staticmethod
def entry_point(env_id: str) -> str:
    """
    Get the entry point for a given environment ID.

    Parameters:
    - env_id: Environment identifier

    Returns:
    str: Entry point string
    """

@staticmethod
def is_atari(env_id: str) -> bool:
    """
    Check if environment is an Atari environment.

    Parameters:
    - env_id: Environment identifier

    Returns:
    bool: True if Atari environment
    """

@staticmethod
def is_minigrid(env_id: str) -> bool:
    """
    Check if environment is a MiniGrid environment.

    Parameters:
    - env_id: Environment identifier

    Returns:
    bool: True if MiniGrid environment
    """

@staticmethod
def is_bullet(env_id: str) -> bool:
    """
    Check if environment is a PyBullet environment.

    Parameters:
    - env_id: Environment identifier

    Returns:
    bool: True if PyBullet environment
    """

@staticmethod
def is_robotics_env(env_id: str) -> bool:
    """
    Check if environment is a robotics environment.

    Parameters:
    - env_id: Environment identifier

    Returns:
    bool: True if robotics environment
    """

@staticmethod
def is_panda_gym(env_id: str) -> bool:
    """
    Check if environment is a Panda Gym environment.

    Parameters:
    - env_id: Environment identifier

    Returns:
    bool: True if Panda Gym environment
    """
```
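
Since `entry_point` returns a string and the `is_*` helpers take only an environment ID, one natural way to implement such checks is a substring match on the entry point. The sketch below shows that pattern with a hypothetical in-memory registry standing in for the real environment registry. The entry-point strings are illustrative, and the matching rules are an assumption rather than the library's confirmed logic.

```python
# Hypothetical stand-in for an environment registry: env ID -> entry point string.
REGISTRY: dict[str, str] = {
    "CartPole-v1": "gymnasium.envs.classic_control.cartpole:CartPoleEnv",
    "ALE/Pong-v5": "ale_py.env:AtariEnv",
    "MiniGrid-Empty-5x5-v0": "minigrid.envs:EmptyEnv",
}


def entry_point(env_id: str) -> str:
    """Look up the entry point string for an environment ID."""
    return REGISTRY[env_id]


def is_atari(env_id: str) -> bool:
    """Treat an environment as Atari if its entry point names an AtariEnv class."""
    return "AtariEnv" in entry_point(env_id)


def is_minigrid(env_id: str) -> bool:
    """Treat an environment as MiniGrid if its entry point lives in a minigrid module."""
    return "minigrid" in entry_point(env_id)


# Classify every registered environment into one coarse family.
families = {
    env_id: "atari" if is_atari(env_id) else "minigrid" if is_minigrid(env_id) else "other"
    for env_id in REGISTRY
}
```

Matching on the entry point rather than on the ID itself keeps the check independent of naming conventions such as version suffixes or namespace prefixes.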

## Usage Examples

### Basic Training Setup

```python
import argparse
from rl_zoo3.exp_manager import ExperimentManager

# Create arguments (typically from command line)
args = argparse.Namespace(
    algo='ppo',
    env='CartPole-v1',
    n_timesteps=10000,
    eval_freq=1000,
    n_eval_episodes=5,
    save_freq=-1,
    verbose=1,
    seed=42
)

# Create experiment manager
exp_manager = ExperimentManager(
    args=args,
    algo='ppo',
    env_id='CartPole-v1',
    log_folder='./logs',
    n_timesteps=10000,
    eval_freq=1000,
    seed=42
)

# Setup and train
model = exp_manager.setup_experiment()
exp_manager.learn(model)
exp_manager.save_trained_model(model)
```

### Advanced Training with Custom Configuration

```python
import argparse
from rl_zoo3.exp_manager import ExperimentManager

# Advanced configuration
args = argparse.Namespace(
    algo='sac',
    env='Pendulum-v1',
    n_timesteps=50000,
    eval_freq=5000,
    n_eval_episodes=10,
    save_freq=10000,
    verbose=1,
    seed=123,
    tensorboard_log='./tb_logs',
    vec_env_type='subproc',
    n_envs=4
)

# Custom hyperparameters
custom_hyperparams = {
    'n_envs': 4,  # not a constructor argument, so it is set as a hyperparameter
    'learning_rate': 0.0003,
    'buffer_size': 50000,
    'batch_size': 64,
    'tau': 0.02,
    'gamma': 0.98
}

# Custom environment kwargs
env_kwargs = {
    'render_mode': None,
    'max_episode_steps': 200
}

# Create experiment manager with custom settings
exp_manager = ExperimentManager(
    args=args,
    algo='sac',
    env_id='Pendulum-v1',
    log_folder='./logs',
    tensorboard_log='./tb_logs',
    n_timesteps=50000,
    eval_freq=5000,
    n_eval_episodes=10,
    save_freq=10000,
    hyperparams=custom_hyperparams,
    env_kwargs=env_kwargs,
    vec_env_type='subproc',
    seed=123,
    show_progress=True
)

# Setup and train
model = exp_manager.setup_experiment()
exp_manager.learn(model)
exp_manager.save_trained_model(model)
```

### Hyperparameter Optimization

```python
import argparse
from rl_zoo3.exp_manager import ExperimentManager

# Setup for hyperparameter optimization
args = argparse.Namespace(
    algo='ppo',
    env='CartPole-v1',
    n_timesteps=10000,
    eval_freq=2000,
    n_eval_episodes=5,
    verbose=0,  # Reduce verbosity for optimization
    seed=42
)

# Create experiment manager for optimization
exp_manager = ExperimentManager(
    args=args,
    algo='ppo',
    env_id='CartPole-v1',
    log_folder='./optim_logs',
    n_timesteps=10000,
    eval_freq=2000,
    optimize_hyperparameters=True,
    n_trials=50,
    n_jobs=2,
    sampler='tpe',
    pruner='median',
    study_name='ppo_cartpole_optimization',
    seed=42
)

# Run hyperparameter optimization
exp_manager.hyperparameters_optimization()
```

### Loading and Continuing Training

```python
import argparse
from rl_zoo3.exp_manager import ExperimentManager

# Setup for loading pre-trained model
args = argparse.Namespace(
    algo='ppo',
    env='CartPole-v1',
    n_timesteps=20000,  # Additional training steps
    eval_freq=1000,
    verbose=1,
    seed=42
)

# Create experiment manager with trained agent
exp_manager = ExperimentManager(
    args=args,
    algo='ppo',
    env_id='CartPole-v1',
    log_folder='./logs',
    trained_agent='./logs/ppo/CartPole-v1_1/best_model.zip',
    n_timesteps=20000,
    eval_freq=1000,
    seed=42
)

# Load model and continue training
model = exp_manager.setup_experiment()  # This will load the trained agent
exp_manager.learn(model)  # Continue training for additional timesteps
exp_manager.save_trained_model(model)
```