# Core Utilities

Essential utilities for working with RL environments, models, and hyperparameters. These functions form the foundation of the RL Zoo3 framework, providing core capabilities for algorithm selection, environment creation, model management, and parameter scheduling.

## Core Imports

```python
from rl_zoo3 import ALGOS, create_test_env, get_trained_models, linear_schedule
from rl_zoo3.utils import (
    get_model_path,
    get_saved_hyperparams,
    get_latest_run_id,
    get_wrapper_class,
    get_class_by_name,
    flatten_dict_observations,
    get_callback_list,
)
```

## Capabilities

### Algorithm Dictionary

Central registry mapping algorithm names to their Stable Baselines3 classes, enabling dynamic algorithm selection and instantiation.

```python { .api }
ALGOS: dict[str, type[BaseAlgorithm]]
```

The ALGOS dictionary contains mappings for:
- `"a2c"`: A2C (Advantage Actor-Critic)
- `"ddpg"`: DDPG (Deep Deterministic Policy Gradient)
- `"dqn"`: DQN (Deep Q-Network)
- `"ppo"`: PPO (Proximal Policy Optimization)
- `"sac"`: SAC (Soft Actor-Critic)
- `"td3"`: TD3 (Twin Delayed Deep Deterministic Policy Gradient)
- `"ars"`: ARS (Augmented Random Search)
- `"crossq"`: CrossQ
- `"qrdqn"`: QRDQN (Quantile Regression DQN)
- `"tqc"`: TQC (Truncated Quantile Critics)
- `"trpo"`: TRPO (Trust Region Policy Optimization)
- `"ppo_lstm"`: RecurrentPPO (PPO with LSTM)

Usage example:
```python
from rl_zoo3 import ALGOS
from stable_baselines3.common.env_util import make_vec_env

# Get the PPO algorithm class
ppo_class = ALGOS["ppo"]

# Create environment and model
env = make_vec_env("CartPole-v1", n_envs=1)
model = ppo_class("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000)
```
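
A name-to-class lookup like the one above can be wrapped in a small helper that fails with a readable message when the name is unknown. The sketch below is only an illustration and not part of RL Zoo3: `resolve_algo`, `DummyA2C`, `DummyPPO`, and `DEMO_ALGOS` are hypothetical stand-ins so that the snippet runs without Stable Baselines3 installed, and lowercasing the name is a design choice of the sketch. With the real library you would pass `ALGOS` as the registry.

```python
from typing import Mapping


class DummyA2C:
    """Hypothetical stand-in for an algorithm class."""


class DummyPPO:
    """Hypothetical stand-in for an algorithm class."""


# Hypothetical registry with the same shape as ALGOS (name -> class)
DEMO_ALGOS: dict[str, type] = {"a2c": DummyA2C, "ppo": DummyPPO}


def resolve_algo(name: str, registry: Mapping[str, type]) -> type:
    """Return the class registered under `name`, or raise a descriptive error."""
    key = name.lower()  # assumption of this sketch: lookups ignore case
    if key not in registry:
        known = ", ".join(sorted(registry))
        raise ValueError(f"Unknown algorithm {name!r}; expected one of: {known}")
    return registry[key]


algo_class = resolve_algo("PPO", DEMO_ALGOS)
```

Validating the name up front keeps a typo in a command-line argument from surfacing as a bare `KeyError` deep inside a training script.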

### Environment Creation

Creates vectorized test environments with proper wrappers, normalization, and configuration for evaluation and testing.

```python { .api }
def create_test_env(
    env_id: str,
    n_envs: int = 1,
    stats_path: Optional[str] = None,
    seed: int = 0,
    log_dir: Optional[str] = None,
    should_render: bool = True,
    hyperparams: Optional[dict[str, Any]] = None,
    env_kwargs: Optional[dict[str, Any]] = None,
    vec_env_cls: Optional[type[VecEnv]] = None,
    vec_env_kwargs: Optional[dict[str, Any]] = None
) -> VecEnv:
    """
    Create a wrapped, monitored VecEnv for testing.

    Parameters:
    - env_id: Environment identifier (e.g., 'CartPole-v1')
    - n_envs: Number of parallel environments
    - stats_path: Path to VecNormalize statistics file
    - seed: Random seed for reproducibility
    - log_dir: Directory for logging environment interactions
    - should_render: Whether to enable rendering (for PyBullet envs)
    - hyperparams: Hyperparameters dict containing env_wrapper settings
    - env_kwargs: Additional keyword arguments for environment creation
    - vec_env_cls: VecEnv class constructor to use
    - vec_env_kwargs: Keyword arguments for VecEnv constructor

    Returns:
    VecEnv: Configured vectorized environment ready for testing
    """
```

Usage example:
```python
import numpy as np
from rl_zoo3 import create_test_env

# Create test environment
env = create_test_env(
    env_id="CartPole-v1",
    n_envs=4,
    seed=42,
    should_render=False,
)

# Step the vectorized environment; with a trained model,
# its predicted actions would replace the random ones
obs = env.reset()
for _ in range(1000):
    # A VecEnv expects one action per sub-environment
    actions = np.array([env.action_space.sample() for _ in range(env.num_envs)])
    obs, rewards, dones, infos = env.step(actions)
```

### Model Discovery

Discovers and lists trained models, either in a local log directory or on the HuggingFace Hub, so they can be loaded and evaluated.

```python { .api }
def get_trained_models(log_folder: str) -> dict[str, tuple[str, str]]:
    """
    Get a dictionary of trained models from a log folder.

    Parameters:
    - log_folder: Path to the directory containing trained models

    Returns:
    dict: Dictionary mapping model names to (algo, env_id) tuples
    """
```

```python { .api }
def get_hf_trained_models(
    organization: str = "sb3",
    check_filename: bool = False
) -> dict[str, tuple[str, str]]:
    """
    Get trained models from HuggingFace Hub.

    Parameters:
    - organization: HuggingFace organization name
    - check_filename: Whether to validate model filenames

    Returns:
    dict: Dictionary mapping model names to (algo, env_id) tuples
    """
```

Usage example:
```python
from rl_zoo3.utils import get_trained_models, get_hf_trained_models

# Get locally trained models
local_models = get_trained_models("./logs")
print("Local models:", local_models)

# Get models from HuggingFace Hub
hf_models = get_hf_trained_models(organization="sb3")
print("HF models:", list(hf_models.keys())[:5])  # Show first 5
```

### Run Management

Utilities for managing training runs and finding the latest run ID for continued training or evaluation.

```python { .api }
def get_latest_run_id(log_path: str, env_name: str) -> int:
    """
    Get the latest run ID for a given environment.

    Parameters:
    - log_path: Path to the log directory
    - env_name: Environment name

    Returns:
    int: Latest run ID (0 if no run is found)
    """
```

```python { .api }
def get_model_path(
    exp_id: int,
    folder: str,
    algo: str,
    env_name: str,
    load_best: bool = False,
    load_checkpoint: Optional[str] = None,
    load_last_checkpoint: bool = False
) -> tuple[str, str, str]:
    """
    Get the path to a trained model and related information.

    Parameters:
    - exp_id: Experiment ID (0 for latest)
    - folder: Log folder path
    - algo: Algorithm name
    - env_name: Environment name
    - load_best: Whether to load the best model
    - load_checkpoint: Specific checkpoint to load (e.g., "100000")
    - load_last_checkpoint: Whether to load the last checkpoint

    Returns:
    tuple[str, str, str]: (name_prefix, model_path, log_path)
    """
```
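
As a rough illustration of what finding the latest run ID involves, the sketch below scans a log directory for folders named `<env_name>_<run_id>` and returns the largest ID, or 0 when there is none. The naming convention is an assumption inferred from paths such as `./logs/ppo/CartPole-v1_1/`, and `find_latest_run_id` is a hypothetical helper, not the library's implementation.

```python
import os
import re
import tempfile


def find_latest_run_id(log_path: str, env_name: str) -> int:
    """Return the highest N among sub-folders named '<env_name>_N', or 0."""
    pattern = re.compile(re.escape(env_name) + r"_(\d+)")
    latest = 0
    if not os.path.isdir(log_path):
        return latest
    for entry in os.listdir(log_path):
        match = pattern.fullmatch(entry)
        # Only count directories whose whole name matches the pattern
        if match and os.path.isdir(os.path.join(log_path, entry)):
            latest = max(latest, int(match.group(1)))
    return latest


# Demo on a throwaway directory tree
with tempfile.TemporaryDirectory() as tmp:
    for run in ("CartPole-v1_1", "CartPole-v1_3", "Pendulum-v1_7"):
        os.makedirs(os.path.join(tmp, run))
    latest_cartpole = find_latest_run_id(tmp, "CartPole-v1")
    latest_missing = find_latest_run_id(tmp, "Acrobot-v1")
```

Returning 0 for "no run yet" lets a caller compute the next run folder as `latest + 1` without a special case.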

### Hyperparameter Management

Loading and managing saved hyperparameters and normalization statistics from trained models.

```python { .api }
def get_saved_hyperparams(
    stats_path: str,
    norm_reward: bool = False,
    test_mode: bool = False
) -> tuple[dict[str, Any], Optional[str]]:
    """
    Load saved hyperparameters from a stats folder.

    Parameters:
    - stats_path: Path to the folder holding the saved hyperparameters and normalization statistics
    - norm_reward: Whether reward normalization was used
    - test_mode: Whether in test mode

    Returns:
    tuple: (hyperparams_dict, stats_path); stats_path is None if the folder does not exist
    """
```

Usage example:
```python
from rl_zoo3 import get_saved_hyperparams

# Load hyperparameters from trained model
hyperparams, stats_path = get_saved_hyperparams("./logs/ppo/CartPole-v1_1/")
print("Loaded hyperparams:", hyperparams)
```

### Parameter Scheduling

Linear scheduling functions for hyperparameters, such as the learning rate, that need to change during training.

```python { .api }
def linear_schedule(initial_value: Union[float, str]) -> SimpleLinearSchedule:
    """
    Create a linear schedule for a hyperparameter.

    Parameters:
    - initial_value: Initial value (float) or string representation

    Returns:
    SimpleLinearSchedule: Callable schedule object
    """
```

```python { .api }
class SimpleLinearSchedule:
    """
    Linear parameter scheduling class.
    """
    def __init__(self, initial_value: float): ...
    def __call__(self, progress_remaining: float) -> float: ...
```

Usage example:
```python
from rl_zoo3 import linear_schedule, ALGOS
from stable_baselines3.common.env_util import make_vec_env

# Create linear learning rate schedule
lr_schedule = linear_schedule(0.001)

# Use with PPO
env = make_vec_env("CartPole-v1", n_envs=1)
model = ALGOS["ppo"](
    "MlpPolicy",
    env,
    learning_rate=lr_schedule,
    verbose=1
)
```
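
Stable Baselines3 calls a schedule with `progress_remaining`, which goes from 1.0 at the start of training to 0.0 at the end. The sketch below shows the arithmetic such a schedule would perform, assuming the value decays linearly to zero as `progress_remaining * initial_value`. `DemoLinearSchedule` is a hypothetical stand-in written for illustration, not the library class.

```python
class DemoLinearSchedule:
    """Hypothetical linear schedule: value shrinks in proportion to remaining progress."""

    def __init__(self, initial_value) -> None:
        # Accept a float or its string representation, e.g. "0.001"
        self.initial_value = float(initial_value)

    def __call__(self, progress_remaining: float) -> float:
        # progress_remaining = 1.0 at the start of training, 0.0 at the end
        return progress_remaining * self.initial_value


schedule = DemoLinearSchedule(0.001)
# Learning rate at the start, halfway point, and end of training
values = [schedule(progress) for progress in (1.0, 0.5, 0.0)]
```

Using a callable object rather than a closure keeps the schedule picklable, which matters when a model that holds it is saved to disk.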

### Environment Wrapper Utilities

Utilities for extracting and applying environment wrappers from hyperparameter configurations.

```python { .api }
def get_wrapper_class(
    hyperparams: dict[str, Any],
    key: str = "env_wrapper"
) -> Optional[Callable[[gym.Env], gym.Env]]:
    """
    Get one or more Gym environment wrapper classes from hyperparams.

    Parameters:
    - hyperparams: Hyperparameters dictionary
    - key: Key in hyperparams containing wrapper specification

    Returns:
    Optional wrapper class or wrapper chain function
    """
```
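
When several wrappers are specified, they have to be applied in order, each one wrapping the result of the previous. A minimal sketch of that composition follows, using plain callables instead of Gym wrappers so that it runs without extra dependencies. `chain_wrappers` and the two toy wrappers are hypothetical names, not RL Zoo3 functions.

```python
from functools import reduce
from typing import Callable, Optional, Sequence

Wrapper = Callable[[object], object]


def chain_wrappers(wrappers: Sequence[Wrapper]) -> Optional[Wrapper]:
    """Combine wrappers into one callable; return None when there is nothing to apply."""
    if not wrappers:
        return None

    def wrap_env(env: object) -> object:
        # Apply wrappers left to right: the first listed is the innermost
        return reduce(lambda wrapped, wrapper: wrapper(wrapped), wrappers, env)

    return wrap_env


# Toy wrappers that just record the nesting order
def add_time_limit(env: object) -> object:
    return ("TimeLimit", env)


def add_monitor(env: object) -> object:
    return ("Monitor", env)


wrap = chain_wrappers([add_time_limit, add_monitor])
wrapped = wrap("base_env")
```

Returning `None` for an empty list mirrors the `Optional` return type in the signature above, so callers can skip wrapping entirely when no wrapper is configured.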

```python { .api }
def get_class_by_name(name: str) -> type:
    """
    Dynamically import a class by its name.

    Parameters:
    - name: Full class name (e.g., 'stable_baselines3.PPO')

    Returns:
    type: The imported class
    """
```
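
Importing a class from a dotted name usually comes down to splitting off the last component, importing the module, and reading the attribute. The sketch below shows that pattern with the standard library only. `import_class` is a hypothetical helper and may differ from how `get_class_by_name` is actually written.

```python
import importlib


def import_class(qualified_name: str) -> type:
    """Import 'package.module.ClassName' and return the class object."""
    module_name, _, class_name = qualified_name.rpartition(".")
    if not module_name:
        raise ValueError(f"Expected a dotted path, got {qualified_name!r}")
    module = importlib.import_module(module_name)
    # Raises AttributeError if the module has no such attribute
    return getattr(module, class_name)


# Demo with a standard-library class
ordered_dict_cls = import_class("collections.OrderedDict")
instance = ordered_dict_cls(a=1)
```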

```python { .api }
def flatten_dict_observations(env: gym.Env) -> gym.Env:
    """
    Flatten dictionary observation spaces.

    Parameters:
    - env: Environment with Dict observation space

    Returns:
    gym.Env: Environment with flattened observation space
    """
```

### Callback Management

Utilities for creating and managing training callbacks from hyperparameter configurations.

```python { .api }
def get_callback_list(hyperparams: dict[str, Any]) -> list[BaseCallback]:
    """
    Get callback list from hyperparams.

    Parameters:
    - hyperparams: Hyperparameters dictionary containing callback specifications

    Returns:
    list[BaseCallback]: List of configured callbacks
    """
```

Usage example:
```python
from rl_zoo3.utils import get_callback_list

# Hyperparams with callbacks: each entry is a class path, or a
# one-item dict mapping the class path to its keyword arguments
hyperparams = {
    "callback": [
        {
            "stable_baselines3.common.callbacks.CheckpointCallback": {
                "save_freq": 1000,
                "save_path": "./checkpoints/",
            }
        }
    ]
}

# Get callback list
callbacks = get_callback_list(hyperparams)
print(f"Created {len(callbacks)} callbacks")
```

### Utility Classes

```python { .api }
class StoreDict(argparse.Action):
    """
    Argparse action for storing dictionary parameters.
    Converts key:value pairs to dictionary entries; each value is evaluated as a Python expression.
    """
    def __call__(self, parser, namespace, values, option_string=None): ...
```

Usage example:
```python
import argparse
from rl_zoo3.utils import StoreDict

parser = argparse.ArgumentParser()
parser.add_argument("--env-kwargs", type=str, nargs="+", action=StoreDict)
# Each argument is key:value, so string values need their own quotes
args = parser.parse_args(["--env-kwargs", "render_mode:'human'", "max_steps:1000"])
print(args.env_kwargs)  # {'render_mode': 'human', 'max_steps': 1000}
```
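
For readers curious how such an action can be built, here is a minimal sketch of a dictionary-storing `argparse.Action`. `KeyValueDict` is a hypothetical class, not the RL Zoo3 implementation. It assumes `key:value` arguments and parses values with `ast.literal_eval`, which accepts only Python literals, falling back to the raw string for anything else.

```python
import argparse
import ast


class KeyValueDict(argparse.Action):
    """Hypothetical action: collect 'key:value' arguments into a dict."""

    def __call__(self, parser, namespace, values, option_string=None):
        result = {}
        for item in values:
            key, separator, raw_value = item.partition(":")
            if not separator:
                parser.error(f"expected key:value, got {item!r}")
            try:
                # Turns "1000" into 1000, "True" into True, "[1, 2]" into a list
                result[key] = ast.literal_eval(raw_value)
            except (ValueError, SyntaxError):
                # Not a Python literal: keep it as a plain string
                result[key] = raw_value
        setattr(namespace, self.dest, result)


demo_parser = argparse.ArgumentParser()
demo_parser.add_argument("--env-kwargs", nargs="+", action=KeyValueDict)
demo_args = demo_parser.parse_args(
    ["--env-kwargs", "max_steps:1000", "render_mode:human"]
)
```

Restricting parsing to literals is a deliberate trade-off in this sketch: it cannot build arbitrary objects from the command line, but it also cannot execute arbitrary code.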