
# RL Zoo3

RL Baselines3 Zoo is a comprehensive training framework for reinforcement learning agents using Stable Baselines3. It provides scripts for training, evaluating, and tuning hyperparameters for RL agents, along with a collection of pre-tuned hyperparameters for common environments and algorithms.

## Package Information

- **Package Name**: rl-zoo3
- **Language**: Python
- **Installation**: `pip install rl_zoo3`

## Core Imports

```python
import rl_zoo3
```

Common imports for core functionality:

```python
from rl_zoo3 import ALGOS, create_test_env, get_trained_models
from rl_zoo3.exp_manager import ExperimentManager
from rl_zoo3.utils import get_saved_hyperparams, linear_schedule
```

## Basic Usage

### Quick Training Example

```python
from rl_zoo3.train import train
import sys

# Set up command line arguments for training
sys.argv = [
    'train.py',
    '--algo', 'ppo',
    '--env', 'CartPole-v1',
    '--n-timesteps', '10000'
]

# Train an agent
train()
```
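
Because `train()` reads its options from `sys.argv`, the argument list can also be assembled programmatically. The sketch below is one way to do that; `build_train_argv` is a hypothetical helper, not part of rl_zoo3, and it uses only the `--algo`, `--env`, and `--n-timesteps` flags from the example above.

```python
def build_train_argv(algo: str, env_id: str, n_timesteps: int) -> list[str]:
    """Return an argv-style list for the training script (hypothetical helper)."""
    return [
        "train.py",  # argv[0] is the program name, not an option
        "--algo", algo,
        "--env", env_id,
        "--n-timesteps", str(n_timesteps),  # argv entries must be strings
    ]


argv = build_train_argv("ppo", "CartPole-v1", 10_000)
print(argv)
```

Assigning the result to `sys.argv` before calling `train()` should reproduce the example above.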

### Using ExperimentManager

```python
import argparse
from rl_zoo3.exp_manager import ExperimentManager

# Create arguments namespace
args = argparse.Namespace(
    algo='ppo',
    env='CartPole-v1',
    n_timesteps=10000,
    eval_freq=1000,
    n_eval_episodes=5,
    save_freq=-1,
    verbose=1
)

# Create the experiment manager
exp_manager = ExperimentManager(
    args=args,
    algo='ppo',
    env_id='CartPole-v1',
    log_folder='./logs',
    n_timesteps=10000
)

# Set up the experiment (creates the model and environment)
results = exp_manager.setup_experiment()
# Some rl_zoo3 releases return (model, saved_hyperparams) rather than the bare model
model = results[0] if isinstance(results, tuple) else results

# Train the model
model.learn(total_timesteps=10000)
```

### Loading and Evaluating Trained Models

```python
from rl_zoo3 import get_trained_models, create_test_env
from rl_zoo3.enjoy import enjoy
import sys

# Get available trained models
trained_models = get_trained_models('./logs')
print("Available models:", trained_models)

# Set up command line arguments for evaluation
sys.argv = [
    'enjoy.py',
    '--algo', 'ppo',
    '--env', 'CartPole-v1',
    '--folder', './logs',
    '--n-timesteps', '1000'
]

# Evaluate the trained agent
enjoy()
```
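
The dictionary returned by `get_trained_models` can be filtered before choosing what to evaluate. The sketch below assumes each value is an `(algo, env_id)` pair, in line with the `dict[str, tuple[str, str]]` return type; the sample entries, their key format, and the `models_for_algo` helper are hypothetical.

```python
def models_for_algo(
    trained_models: dict[str, tuple[str, str]], algo: str
) -> list[str]:
    """Return the sorted environment ids that have a trained model for `algo`."""
    return sorted(
        env_id
        for model_algo, env_id in trained_models.values()
        if model_algo == algo
    )


# Hypothetical result, shaped like the return value of get_trained_models('./logs')
sample = {
    "ppo-CartPole-v1": ("ppo", "CartPole-v1"),
    "ppo-Pendulum-v1": ("ppo", "Pendulum-v1"),
    "dqn-CartPole-v1": ("dqn", "CartPole-v1"),
}
print(models_for_algo(sample, "ppo"))
```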

## Architecture

RL Zoo3 is built around several key components:

- **ExperimentManager**: Central orchestrator for training experiments; handles model creation, environment setup, hyperparameter loading, and training coordination
- **Algorithm Dictionary (ALGOS)**: Maps algorithm names to their Stable Baselines3 classes, supporting A2C, PPO, SAC, TD3, DQN, and more
- **Utilities**: Core functions for environment creation, model loading, hyperparameter management, and file operations
- **Callbacks**: Custom training callbacks for evaluation, hyperparameter optimization, and logging
- **Wrappers**: Environment wrappers for observation processing, reward modification, and training optimization
- **Plotting**: Visualization tools for training curves, evaluation results, and performance analysis

The framework integrates with Optuna for hyperparameter optimization and with the HuggingFace Hub for model sharing. It supports multiple environment libraries, including Gymnasium (the maintained successor to OpenAI Gym), Atari, MuJoCo, and PyBullet.

## Capabilities

### Core Utilities

Essential utilities for working with RL environments, models, and hyperparameters. Includes algorithm mapping, environment creation, model loading, hyperparameter management, and scheduling functions.

```python { .api }
ALGOS: dict[str, type[BaseAlgorithm]]

def create_test_env(
    env_id: str,
    n_envs: int = 1,
    stats_path: Optional[str] = None,
    seed: int = 0,
    log_dir: Optional[str] = None,
    should_render: bool = True,
    hyperparams: Optional[dict[str, Any]] = None,
    env_kwargs: Optional[dict[str, Any]] = None,
    vec_env_cls: Optional[type[VecEnv]] = None,
    vec_env_kwargs: Optional[dict[str, Any]] = None
) -> VecEnv: ...

def get_trained_models(log_folder: str) -> dict[str, tuple[str, str]]: ...

def get_saved_hyperparams(
    stats_path: str,
    norm_reward: bool = False,
    test_mode: bool = False
) -> tuple[dict[str, Any], str]: ...

def linear_schedule(initial_value: Union[float, str]) -> SimpleLinearSchedule: ...
```

[Core Utilities](./core-utilities.md)
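
To illustrate what a linear schedule does, here is a stand-alone sketch. It assumes the Stable Baselines3 convention that a schedule is called with `progress_remaining`, which falls from 1.0 at the start of training to 0.0 at the end. `make_linear_schedule` is an illustrative stand-in, not the library's `linear_schedule`.

```python
from typing import Callable, Union


def make_linear_schedule(initial_value: Union[float, str]) -> Callable[[float], float]:
    """Build a schedule that decays linearly from initial_value to 0."""
    # The signature accepts a string as well, so coerce to float first
    start = float(initial_value)

    def schedule(progress_remaining: float) -> float:
        # progress_remaining goes from 1.0 (start) to 0.0 (end of training)
        return progress_remaining * start

    return schedule


lr = make_linear_schedule("3e-4")
print(lr(1.0), lr(0.5), lr(0.0))
```

A schedule like this is usually passed where a constant such as a learning rate would otherwise go, so the value shrinks as training progresses.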

### Experiment Management

Comprehensive experiment orchestration through the ExperimentManager class, handling training workflows, hyperparameter optimization, environment setup, and model coordination.

```python { .api }
class ExperimentManager:
    def __init__(
        self,
        args: argparse.Namespace,
        algo: str,
        env_id: str,
        log_folder: str,
        tensorboard_log: str = "",
        n_timesteps: int = 0,
        eval_freq: int = 10000,
        n_eval_episodes: int = 5,
        save_freq: int = -1,
        hyperparams: Optional[dict[str, Any]] = None,
        env_kwargs: Optional[dict[str, Any]] = None,
        **kwargs
    ): ...

    def setup_experiment(self) -> BaseAlgorithm: ...
    def learn(self, model: BaseAlgorithm) -> None: ...
    def save_trained_model(self, model: BaseAlgorithm) -> None: ...
```

[Experiment Management](./experiment-management.md)

### Training Callbacks

Custom callbacks for training monitoring, evaluation, hyperparameter optimization, and logging. Includes specialized callbacks for Optuna trials, VecNormalize saving, and parallel training.

```python { .api }
class TrialEvalCallback(EvalCallback):
    def __init__(
        self,
        eval_env: VecEnv,
        trial: optuna.Trial,
        n_eval_episodes: int = 5,
        eval_freq: int = 10000,
        **kwargs
    ): ...

class SaveVecNormalizeCallback(BaseCallback):
    def __init__(self, save_freq: int, save_path: str, **kwargs): ...

class ParallelTrainCallback(BaseCallback): ...
class RawStatisticsCallback(BaseCallback): ...
```

[Training Callbacks](./callbacks.md)
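
The idea behind an Optuna trial callback is to evaluate periodically, report the score to the trial, and stop training early when the trial should be pruned. The following sketch shows that control flow with stand-ins only: `FakeTrial` and `PruningEvalSketch` are hypothetical classes, the evaluation scores are supplied by hand, and none of this is the `TrialEvalCallback` implementation. The `report` and `should_prune` method names mirror those on `optuna.Trial`.

```python
class FakeTrial:
    """Hypothetical stand-in for optuna.Trial: prunes when the last score is too low."""

    def __init__(self, prune_below: float) -> None:
        self.prune_below = prune_below
        self.reported: list[tuple[int, float]] = []

    def report(self, value: float, step: int) -> None:
        self.reported.append((step, value))

    def should_prune(self) -> bool:
        return bool(self.reported) and self.reported[-1][1] < self.prune_below


class PruningEvalSketch:
    """Evaluate every eval_freq calls and ask the trial whether to stop."""

    def __init__(self, trial: FakeTrial, eval_freq: int) -> None:
        self.trial = trial
        self.eval_freq = eval_freq
        self.eval_idx = 0
        self.is_pruned = False

    def on_step(self, n_calls: int, mean_reward: float) -> bool:
        """Return False to stop training, True to continue."""
        if n_calls % self.eval_freq != 0:
            return True  # not an evaluation step
        self.eval_idx += 1
        self.trial.report(mean_reward, self.eval_idx)
        if self.trial.should_prune():
            self.is_pruned = True
            return False
        return True


callback = PruningEvalSketch(FakeTrial(prune_below=50.0), eval_freq=2)
# Hand-written (call count, mean reward) pairs standing in for real evaluations
history = [(1, 80.0), (2, 90.0), (3, 30.0), (4, 40.0), (5, 95.0)]
steps_run = 0
for n_calls, reward in history:
    steps_run = n_calls
    if not callback.on_step(n_calls, reward):
        break
print(steps_run, callback.is_pruned)
```

Only calls 2 and 4 trigger an evaluation here, so the low score at call 3 is never reported and training stops at call 4.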

### Environment Wrappers

Custom Gymnasium environment wrappers for observation processing, reward modification, action manipulation, and training optimization. Includes wrappers for success truncation, action noise, history tracking, and frame skipping.

```python { .api }
class TruncatedOnSuccessWrapper(gym.Wrapper):
    def __init__(self, env: gym.Env, reward_offset: float = 0.0, n_successes: int = 1): ...

class ActionNoiseWrapper(gym.Wrapper):
    def __init__(self, env: gym.Env, noise_std: float = 0.1): ...

class HistoryWrapper(gym.Wrapper):
    def __init__(self, env: gym.Env, horizon: int = 2): ...

class DelayedRewardWrapper(gym.Wrapper):
    def __init__(self, env: gym.Env, delay: int = 10): ...
```

[Environment Wrappers](./wrappers.md)
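
As an illustration of success truncation, the sketch below ends an episode once the environment has reported `n_successes` consecutive successes and adds `reward_offset` to each reward. It is a dependency-free approximation: `ToyEnv` and `TruncateOnSuccessSketch` are hypothetical, reading success from `info["is_success"]` is an assumption, and the real `TruncatedOnSuccessWrapper` may differ in detail.

```python
from typing import Any


class ToyEnv:
    """Hypothetical environment with a Gymnasium-style 5-tuple step()."""

    def __init__(self, successes: list[bool]) -> None:
        self.successes = successes
        self.t = 0

    def reset(self) -> tuple[int, dict[str, Any]]:
        self.t = 0
        return 0, {}

    def step(self, action: int) -> tuple[int, float, bool, bool, dict[str, Any]]:
        info = {"is_success": self.successes[self.t]}
        self.t += 1
        terminated = self.t >= len(self.successes)
        return self.t, 1.0, terminated, False, info


class TruncateOnSuccessSketch:
    """Truncate after n_successes consecutive successful steps."""

    def __init__(self, env: ToyEnv, reward_offset: float = 0.0, n_successes: int = 1) -> None:
        self.env = env
        self.reward_offset = reward_offset
        self.n_successes = n_successes
        self.current_successes = 0

    def reset(self) -> tuple[int, dict[str, Any]]:
        self.current_successes = 0
        return self.env.reset()

    def step(self, action: int) -> tuple[int, float, bool, bool, dict[str, Any]]:
        obs, reward, terminated, truncated, info = self.env.step(action)
        # Count successes in a row; any failure resets the streak
        if info.get("is_success", False):
            self.current_successes += 1
        else:
            self.current_successes = 0
        truncated = truncated or self.current_successes >= self.n_successes
        return obs, reward + self.reward_offset, terminated, truncated, info


env = TruncateOnSuccessSketch(
    ToyEnv([False, True, False, True, True, False]),
    reward_offset=-1.0,
    n_successes=2,
)
env.reset()
steps = 0
done = False
while not done:
    obs, reward, terminated, truncated, info = env.step(0)
    steps += 1
    done = terminated or truncated
print(steps, terminated, truncated, reward)
```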

### Hyperparameter Optimization

Hyperparameter sampling and optimization utilities using Optuna. Includes algorithm-specific parameter samplers and conversion functions for different RL algorithms.

```python { .api }
def sample_ppo_params(trial: optuna.Trial, n_actions: int, n_envs: int, additional_args: dict) -> dict[str, Any]: ...
def sample_sac_params(trial: optuna.Trial, n_actions: int, n_envs: int, additional_args: dict) -> dict[str, Any]: ...
def sample_dqn_params(trial: optuna.Trial, n_actions: int, n_envs: int, additional_args: dict) -> dict[str, Any]: ...
def convert_onpolicy_params(sampled_params: dict[str, Any]) -> dict[str, Any]: ...
def convert_offpolicy_params(sampled_params: dict[str, Any]) -> dict[str, Any]: ...
```

[Hyperparameter Optimization](./hyperparameter-optimization.md)
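
A sampler in this style takes a trial object and returns a dictionary of hyperparameters. The sketch below follows the `(trial, n_actions, n_envs, additional_args)` signature listed above, but everything else is hypothetical: `StubTrial` replaces `optuna.Trial` so the example runs without Optuna, and the parameter names and ranges are illustrative rather than the ones rl_zoo3 uses.

```python
from typing import Any, Sequence


class StubTrial:
    """Hypothetical stand-in for optuna.Trial that returns deterministic suggestions."""

    def suggest_float(self, name: str, low: float, high: float, log: bool = False) -> float:
        # Geometric midpoint on a log scale, arithmetic midpoint otherwise
        return (low * high) ** 0.5 if log else (low + high) / 2

    def suggest_categorical(self, name: str, choices: Sequence[Any]) -> Any:
        return choices[0]


def sample_params_sketch(
    trial: StubTrial, n_actions: int, n_envs: int, additional_args: dict
) -> dict[str, Any]:
    """Sample an illustrative set of hyperparameters.

    n_actions, n_envs, and additional_args are accepted to match the listed
    signature but are unused in this sketch.
    """
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True),
        "gamma": trial.suggest_categorical("gamma", [0.99, 0.995, 0.999]),
        "n_steps": trial.suggest_categorical("n_steps", [128, 256, 512]),
    }


params = sample_params_sketch(StubTrial(), n_actions=2, n_envs=4, additional_args={})
print(params)
```

With a real `optuna.Trial`, the same `suggest_float` and `suggest_categorical` calls would draw values from the study's sampler instead of returning fixed ones.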

### Plotting and Visualization

Comprehensive plotting tools for training curves, evaluation results, and performance analysis. Includes functions for plotting from log files, training progress, and generating publication-quality plots.

```python { .api }
def plot_train(): ...
def plot_from_file(): ...
def all_plots(): ...
def normalize_score(score: np.ndarray, env_id: str) -> np.ndarray: ...
```

[Plotting and Visualization](./plotting.md)
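
Score normalization typically rescales a raw return so that a reference low score maps to 0 and a reference high score maps to 1. The sketch below shows that arithmetic on plain floats; the `REFERENCE_SCORES` table, its values, and `normalize_score_sketch` are hypothetical and do not reproduce rl_zoo3's `normalize_score` or its reference data.

```python
# Hypothetical (low, high) reference returns per environment
REFERENCE_SCORES: dict[str, tuple[float, float]] = {
    "CartPole-v1": (0.0, 500.0),
}


def normalize_score_sketch(scores: list[float], env_id: str) -> list[float]:
    """Rescale scores so the reference range maps onto [0, 1]."""
    if env_id not in REFERENCE_SCORES:
        raise KeyError(f"no reference scores for {env_id!r}")
    low, high = REFERENCE_SCORES[env_id]
    return [(s - low) / (high - low) for s in scores]


normalized = normalize_score_sketch([0.0, 250.0, 500.0], "CartPole-v1")
print(normalized)
```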

### HuggingFace Hub Integration

Model sharing and loading through HuggingFace Hub integration. Includes functions for uploading trained models, downloading pre-trained models, and generating model cards.

```python { .api }
def package_to_hub(
    model: BaseAlgorithm,
    model_name: str,
    repo_id: str,
    commit_message: str = "Add model",
    **kwargs
) -> str: ...

def download_from_hub(
    repo_id: str,
    filename: str,
    **kwargs
) -> str: ...

def generate_model_card(
    model: BaseAlgorithm,
    env_id: str,
    **kwargs
) -> str: ...
```

[HuggingFace Hub Integration](./hub-integration.md)

## Command Line Scripts

RL Zoo3 provides several command-line entry points:

```bash
# Main CLI entry point
rl_zoo3 train --algo ppo --env CartPole-v1
rl_zoo3 enjoy --algo ppo --env CartPole-v1 --folder logs/
rl_zoo3 plot_train --log-dir logs/
rl_zoo3 plot_from_file --log-dir logs/
rl_zoo3 all_plots --log-dir logs/

# Direct script execution
python -m rl_zoo3.train --algo ppo --env CartPole-v1
python -m rl_zoo3.enjoy --algo ppo --env CartPole-v1 --folder logs/
```

## Supported Algorithms

RL Zoo3 supports the following reinforcement learning algorithms:

- **A2C**: Advantage Actor-Critic
- **DDPG**: Deep Deterministic Policy Gradient
- **DQN**: Deep Q-Network
- **PPO**: Proximal Policy Optimization
- **SAC**: Soft Actor-Critic
- **TD3**: Twin Delayed Deep Deterministic Policy Gradient
- **ARS**: Augmented Random Search
- **CrossQ**: CrossQ algorithm
- **QRDQN**: Quantile Regression DQN
- **TQC**: Truncated Quantile Critics
- **TRPO**: Trust Region Policy Optimization
- **PPO_LSTM**: PPO with LSTM policy
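
Algorithm names map to classes through a dictionary, so selecting an algorithm is a keyed lookup. The sketch below mimics that pattern with placeholder classes; `ALGOS_SKETCH`, `get_algo_class`, and the lowercase key convention are assumptions for illustration, whereas the real `ALGOS` dictionary holds Stable Baselines3 classes.

```python
class PPO:  # placeholder, not stable_baselines3.PPO
    pass


class DQN:  # placeholder, not stable_baselines3.DQN
    pass


# Illustrative registry: name -> class
ALGOS_SKETCH: dict[str, type] = {"ppo": PPO, "dqn": DQN}


def get_algo_class(name: str) -> type:
    """Look up an algorithm class by name, with a readable error for typos."""
    key = name.lower()
    if key not in ALGOS_SKETCH:
        known = ", ".join(sorted(ALGOS_SKETCH))
        raise ValueError(f"Unknown algorithm {name!r}; expected one of: {known}")
    return ALGOS_SKETCH[key]


print(get_algo_class("PPO").__name__)
```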

## Types

```python { .api }
from typing import Any, Callable, Optional, Union
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.vec_env import VecEnv
import gymnasium as gym
import optuna
import torch as th

SimpleLinearSchedule = type  # Linear parameter scheduling class
StoreDict = type  # Argparse action for storing dict parameters

# Common type aliases
EnvironmentName = str
ModelName = str
HyperparamDict = dict[str, Any]
CallbackList = list[BaseCallback]
WrapperClass = Callable[[gym.Env], gym.Env]
```