# Training Utilities

Callbacks, noise generators, evaluation tools, and other utilities for enhancing and monitoring training. These components support experiment management, hyperparameter tuning, and production deployment of RL systems.

## Capabilities

### Callback System

Event-driven system for monitoring, evaluating, and controlling training, with customizable hooks at each stage of the training loop.

```python { .api }
class BaseCallback:
    """
    Abstract base class for training callbacks.

    Args:
        verbose: Verbosity level (0: quiet, 1: info, 2: debug)
    """

    def __init__(self, verbose: int = 0): ...

    def init_callback(self, model: "BaseAlgorithm") -> None:
        """Initialize callback with algorithm instance."""

    def on_training_start(
        self, locals_: Dict[str, Any], globals_: Dict[str, Any]
    ) -> None:
        """Called when training begins."""

    def on_rollout_start(self) -> None:
        """Called before collecting rollouts."""

    def on_step(self) -> bool:
        """
        Called after each call to ``env.step()``.

        With a VecEnv, one call advances every sub-environment, i.e. it
        corresponds to ``n_envs`` timesteps.

        Returns:
            True to continue training, False to stop
        """

    def on_rollout_end(self) -> None:
        """Called after rollout collection."""

    def on_training_end(self) -> None:
        """Called when training ends."""

    def update_locals(self, locals_: Dict[str, Any]) -> None:
        """Update callback with current local variables."""


class EventCallback(BaseCallback):
    """
    Base class for event-triggered callbacks.

    Args:
        callback: Child callback to trigger
        verbose: Verbosity level
    """

    def __init__(self, callback: Optional["BaseCallback"] = None, verbose: int = 0): ...

    def _trigger_event(self) -> bool:
        """Trigger child callback if conditions are met."""

    def _on_event(self) -> bool:
        """Event handler (to be implemented by subclasses)."""


class CallbackList(BaseCallback):
    """
    Container for multiple callbacks.

    Args:
        callbacks: List of callback instances
    """

    def __init__(self, callbacks: List[BaseCallback]): ...

    def on_training_start(
        self, locals_: Dict[str, Any], globals_: Dict[str, Any]
    ) -> None:
        """Call on_training_start for all callbacks."""

    def on_step(self) -> bool:
        """Call on_step for all callbacks, stop if any returns False."""


class EvalCallback(EventCallback):
    """
    Evaluate agent during training and save best model.

    Args:
        eval_env: Environment for evaluation
        callback_on_new_best: Callback triggered when new best model found
            (e.g. StopTrainingOnRewardThreshold)
        callback_after_eval: Callback triggered after every evaluation
            (e.g. StopTrainingOnNoModelImprovement)
        n_eval_episodes: Number of episodes for evaluation
        eval_freq: Evaluate every ``eval_freq`` callback calls. On a VecEnv
            each call is ``n_envs`` timesteps, so pass
            ``max(eval_freq // n_envs, 1)`` to evaluate every ``eval_freq``
            timesteps.
        log_path: Path for evaluation logs
        best_model_save_path: Path to save best model
        deterministic: Use deterministic actions during evaluation
        render: Render evaluation episodes
        verbose: Verbosity level
        warn: Show warnings for evaluation issues
    """

    def __init__(
        self,
        eval_env: Union[gym.Env, VecEnv],
        callback_on_new_best: Optional[BaseCallback] = None,
        callback_after_eval: Optional[BaseCallback] = None,
        n_eval_episodes: int = 5,
        eval_freq: int = 10000,
        log_path: Optional[str] = None,
        best_model_save_path: Optional[str] = None,
        deterministic: bool = True,
        render: bool = False,
        verbose: int = 1,
        warn: bool = True,
    ): ...

    def _on_step(self) -> bool:
        """Evaluate model every ``eval_freq`` callback calls."""

    def _on_event(self) -> bool:
        """Perform evaluation and save best model."""


class CheckpointCallback(BaseCallback):
    """
    Save model at regular intervals.

    Args:
        save_freq: Save every ``save_freq`` callback calls. On a VecEnv each
            call is ``n_envs`` timesteps, so pass ``max(save_freq // n_envs, 1)``
            to save every ``save_freq`` timesteps.
        save_path: Directory to save checkpoints
        name_prefix: Prefix for checkpoint filenames
        save_replay_buffer: Whether to save replay buffer
        save_vecnormalize: Whether to save VecNormalize statistics
        verbose: Verbosity level
    """

    def __init__(
        self,
        save_freq: int,
        save_path: str,
        name_prefix: str = "rl_model",
        save_replay_buffer: bool = False,
        save_vecnormalize: bool = False,
        verbose: int = 0,
    ): ...

    def _on_step(self) -> bool:
        """Save a checkpoint every ``save_freq`` callback calls."""


class StopTrainingOnRewardThreshold(BaseCallback):
    """
    Stop training once the best mean evaluation reward reaches a threshold.
    Must be used with EvalCallback, passed as ``callback_on_new_best``:
    it reads the evaluation result from its parent callback, so it cannot
    be used standalone or directly inside a CallbackList.

    Args:
        reward_threshold: Minimum mean evaluation reward to stop training
        verbose: Verbosity level
    """

    def __init__(self, reward_threshold: float, verbose: int = 0): ...

    def _on_step(self) -> bool:
        """Check if reward threshold is reached."""


class StopTrainingOnMaxEpisodes(BaseCallback):
    """
    Stop training after maximum number of episodes.

    Args:
        max_episodes: Maximum number of episodes
        verbose: Verbosity level
    """

    def __init__(self, max_episodes: int, verbose: int = 0): ...

    def _on_step(self) -> bool:
        """Check if maximum episodes reached."""


class ProgressBarCallback(BaseCallback):
    """
    Display training progress bar using tqdm.

    Args:
        refresh_freq: Progress bar refresh frequency
    """

    def __init__(self, refresh_freq: int = 1): ...

    def on_training_start(
        self, locals_: Dict[str, Any], globals_: Dict[str, Any]
    ) -> None:
        """Initialize progress bar."""

    def _on_step(self) -> bool:
        """Update progress bar."""


class EveryNTimesteps(EventCallback):
    """
    Trigger a callback every ``n_steps`` timesteps.

    Args:
        n_steps: Number of timesteps between triggers
        callback: Callback to trigger
    """

    def __init__(self, n_steps: int, callback: BaseCallback): ...


class ConvertCallback(BaseCallback):
    """
    Convert functional callback (old-style) to object.

    Args:
        callback: Optional callback function
        verbose: Verbosity level
    """

    def __init__(self, callback: Optional[Callable], verbose: int = 0): ...


class StopTrainingOnNoModelImprovement(BaseCallback):
    """
    Stop training if no new best model after N consecutive evaluations.
    Must be used with EvalCallback, passed as ``callback_after_eval``.

    Args:
        max_no_improvement_evals: Max consecutive evaluations without improvement
        min_evals: Number of evaluations before counting
        verbose: Verbosity level
    """

    def __init__(self, max_no_improvement_evals: int, min_evals: int = 0, verbose: int = 0): ...
```

### Noise Classes

Action noise generators for exploration in continuous-control environments. Several stochastic processes are available for perturbing actions during training.

```python { .api }
class ActionNoise:
    """Abstract base class for action noise."""

    def __call__(self) -> np.ndarray:
        """Generate noise sample."""

    def reset(self) -> None:
        """Reset noise state."""


class NormalActionNoise(ActionNoise):
    """
    Gaussian action noise for exploration.

    Args:
        mean: Mean of the noise distribution
        sigma: Standard deviation of the noise distribution
    """

    def __init__(self, mean: np.ndarray, sigma: np.ndarray): ...

    def __call__(self) -> np.ndarray:
        """Sample from Gaussian distribution."""

    def reset(self) -> None:
        """Reset noise (no-op for memoryless noise)."""


class OrnsteinUhlenbeckActionNoise(ActionNoise):
    """
    Ornstein-Uhlenbeck process noise for temporally correlated exploration.

    Args:
        mean: Long-run mean of the process
        sigma: Volatility parameter
        theta: Rate of mean reversion
        dt: Time step
        initial_noise: Initial noise value
    """

    def __init__(
        self,
        mean: np.ndarray,
        sigma: np.ndarray,
        theta: float = 0.15,
        dt: float = 1e-2,
        initial_noise: Optional[np.ndarray] = None,
    ): ...

    def __call__(self) -> np.ndarray:
        """Sample next noise value from OU process."""

    def reset(self) -> None:
        """Reset process to initial state."""


class VectorizedActionNoise(ActionNoise):
    """
    Vectorized noise for multiple parallel environments.

    Keeps one independent copy of ``base_noise`` per environment, so
    stateful processes (e.g. Ornstein-Uhlenbeck) evolve separately in each.

    Args:
        base_noise: Noise instance to replicate for every environment
        n_envs: Number of environments (must be a positive integer)
    """

    def __init__(self, base_noise: ActionNoise, n_envs: int): ...

    def __call__(self) -> np.ndarray:
        """Sample noise for all environments (one row per environment)."""

    def reset(self, indices: Optional[Iterable[int]] = None) -> None:
        """Reset the noise instances listed in ``indices``, or all of them when None."""
```

### Evaluation Functions

Utilities for assessing trained agents, including summary statistics and per-episode results across multiple evaluation episodes.

```python { .api }
def evaluate_policy(
    model: "BaseAlgorithm",
    env: Union[gym.Env, VecEnv],
    n_eval_episodes: int = 10,
    deterministic: bool = True,
    render: bool = False,
    callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None,
    reward_threshold: Optional[float] = None,
    return_episode_rewards: bool = False,
    warn: bool = True,
    verbose: int = 1,
) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:
    """
    Run a trained agent for several episodes and summarize how well it does.

    Args:
        model: Trained RL algorithm whose policy is evaluated
        env: Single or vectorized environment to evaluate in
        n_eval_episodes: How many episodes to run
        deterministic: Whether to pick actions deterministically
        render: Whether to render the environment while evaluating
        callback: User-supplied function invoked during evaluation
        reward_threshold: Minimum reward threshold for success
        return_episode_rewards: If True, return per-episode results instead of
            the aggregated mean and standard deviation
        warn: Whether to emit warnings
        verbose: Verbosity level

    Returns:
        ``(mean_reward, std_reward)`` by default, or
        ``(episode_rewards, episode_lengths)`` when ``return_episode_rewards``
        is True.
    """
```

### Utility Functions

General-purpose utilities for reproducibility, device and tensor handling, target-network updates, and other common tasks in RL training.

```python { .api }
def set_random_seed(seed: int, using_cuda: bool = False) -> None:
    """
    Set random seed for reproducibility.

    Seeds Python's ``random`` module, NumPy and PyTorch.

    Args:
        seed: Random seed value
        using_cuda: Whether CUDA is being used; if True, cuDNN is also put in
            deterministic mode, which can slow training down
    """


def get_system_info() -> Dict[str, Any]:
    """
    Get system information for debugging.

    Returns:
        Dictionary containing system information
    """


def get_device(device: Union[torch.device, str] = "auto") -> torch.device:
    """
    Get PyTorch device from string specification.

    Args:
        device: Device specification ("auto", "cpu", "cuda", etc.). "auto"
            selects CUDA when available and falls back to CPU otherwise.

    Returns:
        PyTorch device object
    """


def obs_as_tensor(
    obs: Union[np.ndarray, Dict[str, np.ndarray]], device: torch.device
) -> Union[torch.Tensor, TensorDict]:
    """
    Convert observations to PyTorch tensors.

    Args:
        obs: Observations to convert (an array, or a dict of arrays)
        device: Target device for tensors

    Returns:
        Tensor, or dictionary of tensors with the same keys for dict input
    """


def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> float:
    """
    Calculate the explained variance ``1 - Var[y_true - y_pred] / Var[y_true]``.

    1 means a perfect prediction, 0 means no better than predicting a
    constant, and negative values mean worse than that.

    Args:
        y_pred: Predicted values
        y_true: True values

    Returns:
        Explained variance as a scalar (NaN when ``y_true`` has zero variance)
    """


def polyak_update(
    params: Iterable[torch.nn.Parameter],
    target_params: Iterable[torch.nn.Parameter],
    tau: float,
) -> None:
    """
    Polyak averaging for target network updates, applied in place:
    ``target = tau * param + (1 - tau) * target``.

    Args:
        params: Source parameters
        target_params: Target parameters to update
        tau: Update coefficient (0 = no update, 1 = hard update)
    """


def update_learning_rate(optimizer: torch.optim.Optimizer, learning_rate: float) -> None:
    """
    Set the learning rate of every parameter group of an optimizer.

    Args:
        optimizer: PyTorch optimizer
        learning_rate: New learning rate
    """


def safe_mean(arr: List[float]) -> float:
    """
    Calculate mean safely, handling empty arrays.

    Args:
        arr: Array of values

    Returns:
        Mean value as a float, or NaN if the array is empty
    """
```

### Schedule Functions

Utilities for scheduling the learning rate and other hyperparameters over the course of training. A schedule is a function of `progress_remaining`, which decreases from 1.0 at the start of training to 0.0 at the end.

```python { .api }
def get_schedule_fn(value_schedule: Union[float, str, Schedule]) -> Schedule:
    """
    Turn a hyperparameter specification into a schedule function.

    A schedule maps ``progress_remaining`` (1.0 when training starts, 0.0 when
    it ends) to the value to use at that point of training.

    Args:
        value_schedule: A constant, a string identifier, or an existing schedule

    Returns:
        The corresponding schedule function
    """


def get_linear_fn(
    start: float, end: float, end_fraction: float
) -> Callable[[float], float]:
    """
    Build a schedule that moves linearly from ``start`` to ``end``.

    Args:
        start: Value at the beginning of training
        end: Value reached at ``end_fraction`` of training
        end_fraction: Fraction of total training at which ``end`` is reached

    Returns:
        A linear schedule taking ``progress_remaining`` as input
    """


def constant_fn(val: float) -> Callable[[float], float]:
    """
    Build a schedule that always returns the same value.

    Args:
        val: The value returned regardless of training progress

    Returns:
        A constant schedule function
    """
```

## Usage Examples

### Comprehensive Training Setup

```python
import gymnasium as gym

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import (
    CallbackList,
    CheckpointCallback,
    EvalCallback,
    StopTrainingOnRewardThreshold,
)
from stable_baselines3.common.vec_env import DummyVecEnv

N_ENVS = 4

# Create training and evaluation environments
train_env = DummyVecEnv([lambda: gym.make("CartPole-v1") for _ in range(N_ENVS)])
eval_env = DummyVecEnv([lambda: gym.make("CartPole-v1")])

# Stop once the best mean evaluation reward reaches the threshold.
# This callback reads the evaluation result from its parent, so it must be
# attached to EvalCallback (callback_on_new_best), not put in the CallbackList.
stop_callback = StopTrainingOnRewardThreshold(
    reward_threshold=475.0,  # CartPole-v1's reward threshold (195 is for CartPole-v0)
    verbose=1,
)

# eval_freq / save_freq count callback calls; each call is one step of the
# vectorized env, i.e. N_ENVS timesteps. Divide by N_ENVS so the frequencies
# below are expressed in timesteps.
eval_callback = EvalCallback(
    eval_env,
    callback_on_new_best=stop_callback,
    best_model_save_path="./logs/best_model/",
    log_path="./logs/results/",
    eval_freq=max(10000 // N_ENVS, 1),
    n_eval_episodes=5,
    deterministic=True,
    render=False,
)

checkpoint_callback = CheckpointCallback(
    save_freq=max(50000 // N_ENVS, 1),
    save_path="./logs/checkpoints/",
    name_prefix="ppo_cartpole",
    save_replay_buffer=False,
    save_vecnormalize=False,
)

# Combine callbacks (stop_callback is already triggered through eval_callback)
callback = CallbackList([eval_callback, checkpoint_callback])

# Train with callbacks
model = PPO("MlpPolicy", train_env, verbose=1)
model.learn(total_timesteps=100000, callback=callback)
```

### Action Noise for Continuous Control

```python
import gymnasium as gym
import numpy as np

from stable_baselines3 import TD3
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise

# Action noise needs a continuous (Box) action space
env = gym.make("Pendulum-v1")
n_actions = env.action_space.shape[-1]

# Option 1: Gaussian noise (independent at every step)
action_noise = NormalActionNoise(
    mean=np.zeros(n_actions),
    sigma=0.1 * np.ones(n_actions),
)

# Option 2: Ornstein-Uhlenbeck noise for temporally correlated exploration.
# This replaces option 1 above; keep whichever suits your task.
action_noise = OrnsteinUhlenbeckActionNoise(
    mean=np.zeros(n_actions),
    sigma=0.1 * np.ones(n_actions),
    theta=0.15,
    dt=1e-2,
)

model = TD3(
    "MlpPolicy",
    env,
    action_noise=action_noise,
    verbose=1,
)

model.learn(total_timesteps=100000)
```

### Custom Evaluation and Analysis

```python
import matplotlib.pyplot as plt
import numpy as np

from stable_baselines3.common.evaluation import evaluate_policy

# `model` and `eval_env` come from the training setup example above.

# Detailed evaluation: per-episode rewards and lengths instead of mean/std
episode_rewards, episode_lengths = evaluate_policy(
    model,
    eval_env,
    n_eval_episodes=100,
    deterministic=True,
    return_episode_rewards=True,
)

# Statistical analysis
print(f"Mean reward: {np.mean(episode_rewards):.2f} ± {np.std(episode_rewards):.2f}")
print(f"Mean episode length: {np.mean(episode_lengths):.2f}")

# Plot reward and length distributions side by side
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.hist(episode_rewards, bins=20)
plt.xlabel("Episode Reward")
plt.ylabel("Frequency")

plt.subplot(1, 2, 2)
plt.hist(episode_lengths, bins=20)
plt.xlabel("Episode Length")
plt.ylabel("Frequency")
plt.show()
```

### Learning Rate Scheduling

```python
import gymnasium as gym

from stable_baselines3 import PPO
from stable_baselines3.common.utils import get_linear_fn

env = gym.make("CartPole-v1")

# Linear learning-rate decay from 3e-4 to 1e-5 over the whole run.
# Schedules receive `progress_remaining`, which goes from 1.0 (start) to 0.0 (end).
learning_rate = get_linear_fn(start=3e-4, end=1e-5, end_fraction=1.0)

model = PPO(
    "MlpPolicy",
    env,
    learning_rate=learning_rate,
    verbose=1,
)


# Custom schedule function: step decay halfway through training
def custom_schedule(progress_remaining: float) -> float:
    """Use 3e-4 for the first half of training, then 1e-4."""
    if progress_remaining > 0.5:
        return 3e-4
    return 1e-4


model = PPO(
    "MlpPolicy",
    env,
    learning_rate=custom_schedule,
    verbose=1,
)
```

### Custom Callback Creation

```python
import numpy as np

from stable_baselines3.common.callbacks import BaseCallback


class LoggingCallback(BaseCallback):
    """Custom callback that logs the mean reward of recent episodes."""

    def __init__(self, verbose: int = 0):
        super().__init__(verbose)

    def _on_step(self) -> bool:
        # ep_info_buffer holds info dicts of recently finished episodes;
        # "r" is the episode return recorded by the Monitor wrapper.
        if len(self.model.ep_info_buffer) > 0:
            mean_reward = np.mean([ep_info["r"] for ep_info in self.model.ep_info_buffer])
            self.logger.record("custom/mean_episode_reward", mean_reward)

        return True  # Continue training (returning False stops it)


# Use custom callback (`model` is any algorithm created in the examples above)
custom_callback = LoggingCallback(verbose=1)
model.learn(total_timesteps=100000, callback=custom_callback)
```

## Types

Imports and type aliases referenced by the signatures above.

```python { .api }
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union

import gymnasium as gym
import numpy as np
import torch

from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, EventCallback
from stable_baselines3.common.noise import (
    ActionNoise,
    NormalActionNoise,
    OrnsteinUhlenbeckActionNoise,
    VectorizedActionNoise,
)
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule, TensorDict
from stable_baselines3.common.vec_env import VecEnv
```