# Training Callbacks

Custom callbacks for training monitoring, evaluation, hyperparameter optimization, and logging. These callbacks extend Stable Baselines3's callback system with specialized functionality for RL Zoo3's training workflows.

## Core Imports

```python
from typing import Optional

import optuna
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
from stable_baselines3.common.vec_env import VecEnv

from rl_zoo3.callbacks import (
    TrialEvalCallback,
    SaveVecNormalizeCallback,
    ParallelTrainCallback,
    RawStatisticsCallback,
)
```
## Capabilities

### Trial Evaluation Callback

Specialized callback for Optuna hyperparameter optimization trials, providing evaluation and pruning functionality during optimization runs.

```python { .api }
class TrialEvalCallback(EvalCallback):
    """
    Callback used for evaluating and reporting a trial during hyperparameter optimization.
    Extends EvalCallback with Optuna trial integration.
    """

    def __init__(
        self,
        eval_env: VecEnv,
        trial: optuna.Trial,
        n_eval_episodes: int = 5,
        eval_freq: int = 10000,
        deterministic: bool = True,
        verbose: int = 0,
        best_model_save_path: Optional[str] = None,
        log_path: Optional[str] = None,
    ) -> None:
        """
        Initialize TrialEvalCallback.

        Parameters:
        - eval_env: Vectorized evaluation environment
        - trial: Optuna trial object for reporting results
        - n_eval_episodes: Number of episodes for each evaluation
        - eval_freq: Frequency of evaluation (in timesteps)
        - deterministic: Whether to use deterministic actions during evaluation
        - verbose: Verbosity level
        - best_model_save_path: Path to save the best model
        - log_path: Path for evaluation logs
        """

    def _on_step(self) -> bool:
        """
        Called at each training step. Performs evaluation at the configured
        frequency and reports the result to Optuna.

        Returns:
        bool: Whether training should continue (False when the trial is pruned)
        """
```
Usage example:
```python
import optuna

from rl_zoo3 import create_test_env
from rl_zoo3.callbacks import TrialEvalCallback

def objective(trial):
    # Create evaluation environment
    eval_env = create_test_env("CartPole-v1", n_envs=1)

    # Create callback
    eval_callback = TrialEvalCallback(
        eval_env=eval_env,
        trial=trial,
        n_eval_episodes=10,
        eval_freq=1000,
        deterministic=True,
        verbose=0,
    )

    # Use callback in model training (simplified)
    # model.learn(total_timesteps=10000, callback=eval_callback)

    return eval_callback.best_mean_reward

# Run optimization
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=10)
```
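
The report-and-prune flow that `_on_step` implements can be sketched in plain Python. `FakeTrial` and its `prune_below` threshold are illustrative stand-ins used here so the sketch runs without a study; a real `optuna.Trial` delegates `should_prune()` to the study's pruner.

```python
class FakeTrial:
    """Stand-in for optuna.Trial, just enough to illustrate the flow."""

    def __init__(self, prune_below):
        self.prune_below = prune_below  # hypothetical pruning threshold
        self.reports = []

    def report(self, value, step):
        self.reports.append((step, value))

    def should_prune(self):
        # A real pruner compares against other trials; this stand-in
        # simply prunes when the latest reward falls below a threshold.
        return bool(self.reports) and self.reports[-1][1] < self.prune_below

def on_step(n_calls, eval_freq, trial, evaluate):
    """Mirror of TrialEvalCallback._on_step: evaluate, report, maybe stop."""
    if eval_freq > 0 and n_calls % eval_freq == 0:
        mean_reward = evaluate()
        trial.report(mean_reward, n_calls)
        if trial.should_prune():
            return False  # a False return tells SB3 to stop training
    return True
```

Returning `False` is what lets Optuna cut short unpromising trials mid-training instead of running every trial to completion.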
### VecNormalize Saving Callback

Callback for automatically saving VecNormalize statistics during training, ensuring normalization parameters are preserved with the model.

```python { .api }
class SaveVecNormalizeCallback(BaseCallback):
    """
    Callback for saving VecNormalize statistics at regular intervals.
    """

    def __init__(
        self,
        save_freq: int,
        save_path: str,
        name_prefix: str = "vecnormalize",
        verbose: int = 0
    ):
        """
        Initialize SaveVecNormalizeCallback.

        Parameters:
        - save_freq: Frequency of saving (in timesteps)
        - save_path: Directory path for saving statistics
        - name_prefix: Prefix for saved files
        - verbose: Verbosity level
        """

    def _on_step(self) -> bool:
        """
        Called at each training step. Saves VecNormalize stats at specified frequency.

        Returns:
        bool: Always True (never stops training)
        """
```
Usage example:
```python
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize

from rl_zoo3.callbacks import SaveVecNormalizeCallback

# Create normalized environment
base_env = make_vec_env("CartPole-v1", n_envs=1)
env = VecNormalize(base_env, norm_obs=True, norm_reward=True)

# Create callback
save_callback = SaveVecNormalizeCallback(
    save_freq=5000,
    save_path="./logs/",
    name_prefix="ppo_vecnormalize",
    verbose=1,
)

# Train with callback
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=20000, callback=save_callback)
```
### Parallel Training Callback

Callback for coordinating parallel training processes, handling synchronization and communication between multiple training instances.

```python { .api }
class ParallelTrainCallback(BaseCallback):
    """
    Callback for parallel training coordination.
    Handles synchronization between multiple training processes.
    """

    def __init__(self, verbose: int = 0):
        """
        Initialize ParallelTrainCallback.

        Parameters:
        - verbose: Verbosity level
        """

    def _on_training_start(self) -> None:
        """
        Called when training starts. Sets up parallel coordination.
        """

    def _on_step(self) -> bool:
        """
        Called at each training step. Handles parallel synchronization.

        Returns:
        bool: Whether training should continue
        """

    def _on_training_end(self) -> None:
        """
        Called when training ends. Cleans up parallel resources.
        """
```
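
No usage example is shown above because this callback is normally wired in by rl_zoo3's training script; what matters is the hook ordering, which is the standard SB3 callback lifecycle. A stand-in callback makes that ordering concrete:

```python
class LifecycleProbe:
    """Stand-in callback that records which hooks fire, in order."""

    def __init__(self):
        self.events = []

    def _on_training_start(self):
        self.events.append("start")

    def _on_step(self):
        self.events.append("step")
        return True

    def _on_training_end(self):
        self.events.append("end")

def run_training_loop(callback, n_steps):
    # Simplified version of the loop SB3 drives around any BaseCallback:
    # start hook once, step hook every timestep, end hook once.
    callback._on_training_start()
    for _ in range(n_steps):
        if not callback._on_step():  # a False return aborts training early
            break
    callback._on_training_end()
```

ParallelTrainCallback uses `_on_training_start` to spawn its coordination resources and `_on_training_end` to tear them down, so both hooks are guaranteed to bracket the per-step synchronization.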
### Raw Statistics Callback

Callback for logging detailed training statistics and metrics, providing comprehensive monitoring of training progress.

```python { .api }
class RawStatisticsCallback(BaseCallback):
    """
    Callback for logging raw training statistics.
    Provides detailed metrics beyond standard Stable Baselines3 logging.
    """

    def __init__(
        self,
        verbose: int = 0,
        log_freq: int = 1000
    ):
        """
        Initialize RawStatisticsCallback.

        Parameters:
        - verbose: Verbosity level
        - log_freq: Frequency of detailed logging (in timesteps)
        """

    def _on_step(self) -> bool:
        """
        Called at each training step. Logs detailed statistics at specified frequency.

        Returns:
        bool: Always True (never stops training)
        """

    def _on_rollout_end(self) -> None:
        """
        Called at the end of each rollout. Logs rollout statistics.
        """
```
Usage example:
```python
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

from rl_zoo3.callbacks import RawStatisticsCallback

env = make_vec_env("CartPole-v1", n_envs=1)

# Create callback
stats_callback = RawStatisticsCallback(
    verbose=1,
    log_freq=1000,
)

# Train with detailed statistics logging
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=50000, callback=stats_callback)
```
## Callback Combinations

You can combine multiple callbacks for comprehensive training monitoring:

```python
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback, EvalCallback
from stable_baselines3.common.env_util import make_vec_env

from rl_zoo3.callbacks import RawStatisticsCallback, SaveVecNormalizeCallback

# Training and evaluation environments
env = make_vec_env("CartPole-v1", n_envs=1)
eval_env = make_vec_env("CartPole-v1", n_envs=1)

# Create multiple callbacks
checkpoint_callback = CheckpointCallback(
    save_freq=10000,
    save_path="./checkpoints/",
    name_prefix="ppo_model",
)

eval_callback = EvalCallback(
    eval_env=eval_env,
    best_model_save_path="./best_models/",
    log_path="./eval_logs/",
    eval_freq=5000,
    n_eval_episodes=10,
    deterministic=True,
)

save_vec_callback = SaveVecNormalizeCallback(
    save_freq=5000,
    save_path="./vecnormalize/",
    verbose=1,
)

stats_callback = RawStatisticsCallback(
    verbose=1,
    log_freq=1000,
)

# Combine callbacks
callback_list = CallbackList([
    checkpoint_callback,
    eval_callback,
    save_vec_callback,
    stats_callback,
])

# Train with all callbacks
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=100000, callback=callback_list)
```
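
`CallbackList` invokes every child on each step and continues training only while all of them return True; a pure-Python sketch of that aggregation:

```python
def callback_list_on_step(callbacks):
    # Mirrors CallbackList: every child runs (none is skipped), and the
    # combined result is True only if all children want to continue.
    continue_training = True
    for cb in callbacks:
        continue_training = cb() and continue_training
    return continue_training
```

Note that every child still runs even after one has voted to stop, so side effects such as checkpointing are not skipped on the final step.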
## Integration with ExperimentManager

The ExperimentManager automatically creates and configures appropriate callbacks based on the training setup:

```python
import argparse

from rl_zoo3.exp_manager import ExperimentManager

# ExperimentManager will automatically create callbacks based on configuration
args = argparse.Namespace(
    algo='ppo',
    env='CartPole-v1',
    n_timesteps=50000,
    eval_freq=5000,  # Will create an EvalCallback
    save_freq=10000,  # Will create a CheckpointCallback
    verbose=1,
)

exp_manager = ExperimentManager(
    args=args,
    algo='ppo',
    env_id='CartPole-v1',
    log_folder='./logs',
    n_timesteps=50000,
    eval_freq=5000,
    save_freq=10000,
)

# Callbacks are automatically created and configured;
# setup_experiment() returns the model together with the saved hyperparameters
model, saved_hyperparams = exp_manager.setup_experiment()
exp_manager.learn(model)  # Uses the automatically created callbacks
```
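
One detail to keep in mind: callbacks are invoked once per step of the vectorized environment, so with n parallel environments each call covers n timesteps. ExperimentManager compensates by dividing the requested frequencies by the number of environments; the helper below is a sketch of that scaling (clamping to 1 mirrors rl_zoo3's behavior):

```python
def scaled_freq(requested_freq: int, n_envs: int) -> int:
    # With n_envs parallel environments, each callback call corresponds
    # to n_envs timesteps, so the per-call frequency is divided
    # accordingly and clamped to at least 1.
    return max(requested_freq // n_envs, 1)
```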
## Custom Callback Integration

You can also create custom callbacks that work with RL Zoo3's training system:

```python
from stable_baselines3.common.callbacks import BaseCallback

class CustomCallback(BaseCallback):
    """
    Custom callback for specific training needs.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)
        self.custom_metric = 0

    def _on_training_start(self) -> None:
        print("Custom training logic started")

    def _on_step(self) -> bool:
        # Custom logic here
        self.custom_metric += 1

        # Log custom metrics
        if self.n_calls % 1000 == 0:
            self.logger.record("custom/metric", self.custom_metric)

        return True  # Continue training

# Use custom callback with ExperimentManager
# You would typically integrate this through hyperparams configuration
# or by modifying the ExperimentManager's create_callbacks method
```
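
In practice the usual way to attach a custom callback in RL Zoo3 is through the hyperparameter YAML files, which accept a `callback` entry holding a fully qualified class name, optionally with keyword arguments. A sketch, where `my_package.callbacks.CustomCallback` is a placeholder import path:

```yaml
CartPole-v1:
  n_timesteps: 50000
  policy: "MlpPolicy"
  callback:
    - my_package.callbacks.CustomCallback:
        verbose: 1
```

rl_zoo3 imports the class by its dotted path at startup and instantiates it with the given keyword arguments alongside the automatically created callbacks.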