A Training Framework for Stable Baselines3 Reinforcement Learning Agents
npx @tessl/cli install tessl/pypi-rl-zoo3@2.7.00
# RL Zoo3

RL Baselines3 Zoo is a comprehensive training framework for reinforcement learning agents using Stable Baselines3. It provides scripts for training, evaluating, and tuning hyperparameters for RL agents, along with a collection of pre-tuned hyperparameters for common environments and algorithms.

## Package Information

- **Package Name**: rl-zoo3
- **Language**: Python
- **Installation**: `pip install rl_zoo3`

## Core Imports

```python
import rl_zoo3
```

Common imports for core functionality:

```python
from rl_zoo3 import ALGOS, create_test_env, get_trained_models
from rl_zoo3.exp_manager import ExperimentManager
from rl_zoo3.utils import get_saved_hyperparams, linear_schedule
```

## Basic Usage

### Quick Training Example

```python
import sys

from rl_zoo3.train import train

# Set up command line arguments for training
sys.argv = [
    'train.py',
    '--algo', 'ppo',
    '--env', 'CartPole-v1',
    '--n-timesteps', '10000'
]

# Train an agent (train() parses sys.argv)
train()
```

### Using ExperimentManager

```python
import argparse

from rl_zoo3.exp_manager import ExperimentManager

# Create arguments namespace
args = argparse.Namespace(
    algo='ppo',
    env='CartPole-v1',
    n_timesteps=10000,
    eval_freq=1000,
    n_eval_episodes=5,
    save_freq=-1,
    verbose=1
)

# Create and setup experiment
exp_manager = ExperimentManager(
    args=args,
    algo='ppo',
    env_id='CartPole-v1',
    log_folder='./logs',
    n_timesteps=10000
)

# Setup experiment (creates model and environment)
model = exp_manager.setup_experiment()

# Train the model
model.learn(total_timesteps=10000)
```

### Loading and Evaluating Trained Models

```python
import sys

from rl_zoo3 import get_trained_models, create_test_env
from rl_zoo3.enjoy import enjoy

# Get available trained models
trained_models = get_trained_models('./logs')
print("Available models:", trained_models)

# Set up command line arguments for evaluation
sys.argv = [
    'enjoy.py',
    '--algo', 'ppo',
    '--env', 'CartPole-v1',
    '--folder', './logs',
    '--n-timesteps', '1000'
]

# Evaluate the trained agent (enjoy() parses sys.argv)
enjoy()
```

## Architecture

RL Zoo3 is built around several key components:

- **ExperimentManager**: Central orchestrator for training experiments that handles model creation, environment setup, hyperparameter loading, and training coordination
- **Algorithm Dictionary (ALGOS)**: Maps algorithm names to their Stable Baselines3 classes, supporting A2C, PPO, SAC, TD3, DQN, and more
- **Utilities**: Core functions for environment creation, model loading, hyperparameter management, and file operations
- **Callbacks**: Custom training callbacks for evaluation, hyperparameter optimization, and logging
- **Wrappers**: Environment wrappers for observation processing, reward modification, and training optimization
- **Plotting**: Visualization tools for training curves, evaluation results, and performance analysis

The framework integrates with Optuna for hyperparameter optimization and HuggingFace Hub for model sharing, and it supports multiple environment libraries including Gymnasium (the successor to OpenAI Gym), Atari, MuJoCo, and PyBullet.

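
To illustrate the Algorithm Dictionary idea described above, here is a minimal sketch of a name-to-class registry with a lookup helper. The classes `DummyPPO` and `DummySAC`, the dictionary `ALGOS_SKETCH`, and the helper `resolve_algo` are hypothetical stand-ins introduced for illustration; the real `ALGOS` maps names to Stable Baselines3 classes.

```python
# Hypothetical stand-ins for algorithm classes; the real ALGOS dictionary
# maps names to Stable Baselines3 classes instead.
class DummyPPO:
    pass


class DummySAC:
    pass


# Name-to-class registry with the same shape as ALGOS (str -> class).
ALGOS_SKETCH: dict[str, type] = {"ppo": DummyPPO, "sac": DummySAC}


def resolve_algo(name: str) -> type:
    """Look up an algorithm class by name, with a readable error for typos."""
    try:
        return ALGOS_SKETCH[name.lower()]
    except KeyError:
        known = ", ".join(sorted(ALGOS_SKETCH))
        raise ValueError(f"Unknown algorithm '{name}'. Known: {known}") from None


print(resolve_algo("PPO").__name__)  # prints: DummyPPO
```

Keeping the mapping in one dictionary means a command-line flag such as `--algo ppo` can be turned into a class with a single lookup.
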
## Capabilities

### Core Utilities

Essential utilities for working with RL environments, models, and hyperparameters. Includes algorithm mapping, environment creation, model loading, hyperparameter management, and scheduling functions.

```python { .api }
ALGOS: dict[str, type[BaseAlgorithm]]

def create_test_env(
    env_id: str,
    n_envs: int = 1,
    stats_path: Optional[str] = None,
    seed: int = 0,
    log_dir: Optional[str] = None,
    should_render: bool = True,
    hyperparams: Optional[dict[str, Any]] = None,
    env_kwargs: Optional[dict[str, Any]] = None,
    vec_env_cls: Optional[type[VecEnv]] = None,
    vec_env_kwargs: Optional[dict[str, Any]] = None
) -> VecEnv: ...

def get_trained_models(log_folder: str) -> dict[str, tuple[str, str]]: ...

def get_saved_hyperparams(
    stats_path: str,
    norm_reward: bool = False,
    test_mode: bool = False
) -> tuple[dict[str, Any], str]: ...

def linear_schedule(initial_value: Union[float, str]) -> SimpleLinearSchedule: ...
```

[Core Utilities](./core-utilities.md)
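
As a rough illustration of what a linear schedule does, the sketch below decays a value linearly to zero. It assumes the Stable Baselines3 convention that a schedule is called with `progress_remaining`, which runs from 1.0 at the start of training to 0.0 at the end. `LinearScheduleSketch` is a hypothetical stand-in written for this example, not the library's `SimpleLinearSchedule`.

```python
from typing import Union


class LinearScheduleSketch:
    """Hypothetical stand-in: linearly decays a value as training progresses."""

    def __init__(self, initial_value: Union[float, str]) -> None:
        # Accept strings such as "3e-4", mirroring the Union[float, str] signature.
        self.initial_value = float(initial_value)

    def __call__(self, progress_remaining: float) -> float:
        # progress_remaining goes from 1.0 (start) to 0.0 (end of training).
        return progress_remaining * self.initial_value


schedule = LinearScheduleSketch("0.001")
for progress in (1.0, 0.5, 0.0):
    print(progress, schedule(progress))
```

A callable like this can be passed wherever Stable Baselines3 accepts a schedule, for example as a `learning_rate`.
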

### Experiment Management

Comprehensive experiment orchestration through the ExperimentManager class, handling training workflows, hyperparameter optimization, environment setup, and model coordination.

```python { .api }
class ExperimentManager:
    def __init__(
        self,
        args: argparse.Namespace,
        algo: str,
        env_id: str,
        log_folder: str,
        tensorboard_log: str = "",
        n_timesteps: int = 0,
        eval_freq: int = 10000,
        n_eval_episodes: int = 5,
        save_freq: int = -1,
        hyperparams: Optional[dict[str, Any]] = None,
        env_kwargs: Optional[dict[str, Any]] = None,
        **kwargs
    ): ...

    def setup_experiment(self) -> BaseAlgorithm: ...
    def learn(self, model: BaseAlgorithm) -> None: ...
    def save_trained_model(self, model: BaseAlgorithm) -> None: ...
```

[Experiment Management](./experiment-management.md)

### Training Callbacks

Custom callbacks for training monitoring, evaluation, hyperparameter optimization, and logging. Includes specialized callbacks for Optuna trials, VecNormalize saving, and parallel training.

```python { .api }
class TrialEvalCallback(EvalCallback):
    def __init__(
        self,
        eval_env: VecEnv,
        trial: optuna.Trial,
        n_eval_episodes: int = 5,
        eval_freq: int = 10000,
        **kwargs
    ): ...

class SaveVecNormalizeCallback(BaseCallback):
    def __init__(self, save_freq: int, save_path: str, **kwargs): ...

class ParallelTrainCallback(BaseCallback): ...
class RawStatisticsCallback(BaseCallback): ...
```

[Training Callbacks](./callbacks.md)
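
The `save_freq` and `save_path` parameters suggest a periodic trigger. The sketch below shows that pattern in plain Python, without depending on Stable Baselines3. `PeriodicSaveSketch`, its `on_step` method, and the `stats_<n>.pkl` file naming are hypothetical choices made for this example; the real callback's file names and internals may differ.

```python
import os


class PeriodicSaveSketch:
    """Hypothetical stand-in: records a save path every `save_freq` calls."""

    def __init__(self, save_freq: int, save_path: str) -> None:
        self.save_freq = save_freq
        self.save_path = save_path
        self.n_calls = 0
        self.saved: list[str] = []

    def on_step(self) -> bool:
        self.n_calls += 1
        if self.n_calls % self.save_freq == 0:
            # A real callback would write normalization statistics to disk here;
            # this sketch only records the path it would have written to.
            path = os.path.join(self.save_path, f"stats_{self.n_calls}.pkl")
            self.saved.append(path)
        # Returning True signals "keep training", as in Stable Baselines3 callbacks.
        return True


callback = PeriodicSaveSketch(save_freq=3, save_path="logs")
for _ in range(7):
    callback.on_step()
print(callback.saved)  # two entries: after call 3 and after call 6
```
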

### Environment Wrappers

Custom Gymnasium environment wrappers for observation processing, reward modification, action manipulation, and training optimization. Includes wrappers for success truncation, action noise, history tracking, and frame skipping.

```python { .api }
class TruncatedOnSuccessWrapper(gym.Wrapper):
    def __init__(self, env: gym.Env, reward_offset: float = 0.0, n_successes: int = 1): ...

class ActionNoiseWrapper(gym.Wrapper):
    def __init__(self, env: gym.Env, noise_std: float = 0.1): ...

class HistoryWrapper(gym.Wrapper):
    def __init__(self, env: gym.Env, horizon: int = 2): ...

class DelayedRewardWrapper(gym.Wrapper):
    def __init__(self, env: gym.Env, delay: int = 10): ...
```

[Environment Wrappers](./wrappers.md)
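
To make the reward-modification idea concrete, here is a minimal sketch of delayed rewards in plain Python. It assumes the reward is accumulated and released as a sum every `delay` steps or when the episode ends; that behavior is an assumption for illustration, and `DelayedRewardSketch` is a hypothetical helper rather than the `DelayedRewardWrapper` class itself.

```python
class DelayedRewardSketch:
    """Hypothetical helper: releases the summed reward every `delay` steps."""

    def __init__(self, delay: int = 10) -> None:
        self.delay = delay
        self.reset()

    def reset(self) -> None:
        # Called at the start of each episode.
        self.current_step = 0
        self.accumulated = 0.0

    def step(self, reward: float, done: bool) -> float:
        self.accumulated += reward
        self.current_step += 1
        # Release the accumulated reward on every `delay`-th step or at episode end.
        if self.current_step % self.delay == 0 or done:
            released = self.accumulated
            self.accumulated = 0.0
            return released
        return 0.0


delayed = DelayedRewardSketch(delay=3)
rewards = [1.0, 1.0, 1.0, 1.0, 1.0]
dones = [False, False, False, False, True]
print([delayed.step(r, d) for r, d in zip(rewards, dones)])
# prints: [0.0, 0.0, 3.0, 0.0, 2.0]
```

The total reward per episode is unchanged; only its timing differs, which makes credit assignment harder for the agent.
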

### Hyperparameter Optimization

Hyperparameter sampling and optimization utilities using Optuna. Includes algorithm-specific parameter samplers and conversion functions for different RL algorithms.

```python { .api }
def sample_ppo_params(trial: optuna.Trial, n_actions: int, n_envs: int, additional_args: dict) -> dict[str, Any]: ...
def sample_sac_params(trial: optuna.Trial, n_actions: int, n_envs: int, additional_args: dict) -> dict[str, Any]: ...
def sample_dqn_params(trial: optuna.Trial, n_actions: int, n_envs: int, additional_args: dict) -> dict[str, Any]: ...
def convert_onpolicy_params(sampled_params: dict[str, Any]) -> dict[str, Any]: ...
def convert_offpolicy_params(sampled_params: dict[str, Any]) -> dict[str, Any]: ...
```

[Hyperparameter Optimization](./hyperparameter-optimization.md)
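
The two-stage pattern of sampling raw values and then converting them into algorithm arguments can be sketched as follows. Everything here is illustrative: `StubTrial` is a hypothetical stand-in that only mimics the `suggest_float` and `suggest_categorical` method names of an Optuna trial, and the search ranges, the `one_minus_gamma` parameter, and both function names are assumptions rather than the library's actual samplers.

```python
import math
import random
from typing import Any, Sequence


class StubTrial:
    """Hypothetical stand-in mimicking two Optuna trial method names."""

    def __init__(self, seed: int = 0) -> None:
        self._rng = random.Random(seed)

    def suggest_float(self, name: str, low: float, high: float, log: bool = False) -> float:
        if log:
            # Sample uniformly in log space.
            return math.exp(self._rng.uniform(math.log(low), math.log(high)))
        return self._rng.uniform(low, high)

    def suggest_categorical(self, name: str, choices: Sequence[Any]) -> Any:
        return self._rng.choice(list(choices))


def sample_params_sketch(trial: Any) -> dict[str, Any]:
    """Stage 1: sample raw values (ranges are made up for illustration)."""
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True),
        "one_minus_gamma": trial.suggest_float("one_minus_gamma", 1e-4, 1e-1, log=True),
        "batch_size": trial.suggest_categorical("batch_size", [32, 64, 128]),
    }


def convert_params_sketch(sampled: dict[str, Any]) -> dict[str, Any]:
    """Stage 2: turn sampled values into keyword arguments for an algorithm."""
    params = dict(sampled)
    params["gamma"] = 1.0 - params.pop("one_minus_gamma")
    return params


print(convert_params_sketch(sample_params_sketch(StubTrial(seed=0))))
```

Separating the stages keeps the search space convenient to sample (for example, `1 - gamma` on a log scale) while the converted dictionary uses the names the algorithm expects.
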

### Plotting and Visualization

Comprehensive plotting tools for training curves, evaluation results, and performance analysis. Includes functions for plotting from log files, training progress, and generating publication-quality plots.

```python { .api }
def plot_train(): ...
def plot_from_file(): ...
def all_plots(): ...
def normalize_score(score: np.ndarray, env_id: str) -> np.ndarray: ...
```

[Plotting and Visualization](./plotting.md)
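
One common way to compare scores across environments is min-max normalization against per-environment reference scores. Whether `normalize_score` works exactly this way is not stated here, so treat the following as a sketch under that assumption. The `REFERENCE_SCORES` table, its values, the environment name `ExampleEnv-v0`, and the function `normalize_score_sketch` are all made up for illustration.

```python
# Hypothetical (low, high) reference scores per environment; values are made up.
REFERENCE_SCORES: dict[str, tuple[float, float]] = {
    "ExampleEnv-v0": (10.0, 110.0),
}


def normalize_score_sketch(scores: list[float], env_id: str) -> list[float]:
    """Map raw scores so the low reference becomes 0.0 and the high one 1.0."""
    low, high = REFERENCE_SCORES[env_id]
    return [(score - low) / (high - low) for score in scores]


print(normalize_score_sketch([10.0, 60.0, 110.0], "ExampleEnv-v0"))
# prints: [0.0, 0.5, 1.0]
```
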

### HuggingFace Hub Integration

Model sharing and loading through HuggingFace Hub integration. Includes functions for uploading trained models, downloading pre-trained models, and generating model cards.

```python { .api }
def package_to_hub(
    model: BaseAlgorithm,
    model_name: str,
    repo_id: str,
    commit_message: str = "Add model",
    **kwargs
) -> str: ...

def download_from_hub(
    repo_id: str,
    filename: str,
    **kwargs
) -> str: ...

def generate_model_card(
    model: BaseAlgorithm,
    env_id: str,
    **kwargs
) -> str: ...
```

[HuggingFace Hub Integration](./hub-integration.md)

## Command Line Scripts

RL Zoo3 provides several command-line entry points:

```bash
# Main CLI entry point
rl_zoo3 train --algo ppo --env CartPole-v1
rl_zoo3 enjoy --algo ppo --env CartPole-v1 --folder logs/
rl_zoo3 plot_train --log-dir logs/
rl_zoo3 plot_from_file --log-dir logs/
rl_zoo3 all_plots --log-dir logs/

# Direct script execution
python -m rl_zoo3.train --algo ppo --env CartPole-v1
python -m rl_zoo3.enjoy --algo ppo --env CartPole-v1 --folder logs/
```

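
When these commands are launched from Python, for example in a sweep script, it can help to build the argument list programmatically. The helper below is a small sketch; `build_cli_args` is a hypothetical function, and it simply converts keyword names such as `n_timesteps` into flags such as `--n-timesteps` without checking that the flag exists.

```python
import shlex


def build_cli_args(subcommand: str, **options: object) -> list[str]:
    """Build an `rl_zoo3 <subcommand> --flag value ...` argument list."""
    args = ["rl_zoo3", subcommand]
    for key, value in options.items():
        # Convert Python-style names (n_timesteps) to CLI flags (--n-timesteps).
        args += [f"--{key.replace('_', '-')}", str(value)]
    return args


cmd = build_cli_args("train", algo="ppo", env="CartPole-v1", n_timesteps=10000)
print(shlex.join(cmd))
# prints: rl_zoo3 train --algo ppo --env CartPole-v1 --n-timesteps 10000
```

The resulting list can be passed to `subprocess.run(cmd, check=True)` once the package is installed.
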
## Supported Algorithms

RL Zoo3 supports the following reinforcement learning algorithms:

- **A2C**: Advantage Actor-Critic
- **DDPG**: Deep Deterministic Policy Gradient
- **DQN**: Deep Q-Network
- **PPO**: Proximal Policy Optimization
- **SAC**: Soft Actor-Critic
- **TD3**: Twin Delayed Deep Deterministic Policy Gradient
- **ARS**: Augmented Random Search
- **CrossQ**: CrossQ algorithm
- **QRDQN**: Quantile Regression DQN
- **TQC**: Truncated Quantile Critics
- **TRPO**: Trust Region Policy Optimization
- **PPO_LSTM**: PPO with LSTM policy

## Types

```python { .api }
from typing import Any, Callable, Optional, Union
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.vec_env import VecEnv
import gymnasium as gym
import optuna
import torch as th

SimpleLinearSchedule = type  # Linear parameter scheduling class
StoreDict = type  # Argparse action for storing dict parameters

# Common type aliases
EnvironmentName = str
ModelName = str
HyperparamDict = dict[str, Any]
CallbackList = list[BaseCallback]
WrapperClass = Callable[[gym.Env], gym.Env]
```