0
# Common Framework
1
2
Base classes, policies, and buffers that provide the foundation for all algorithms and enable consistent behavior across the Stable Baselines3 library. This framework promotes code reuse and ensures uniform interfaces across different algorithm implementations.
3
4
## Capabilities
5
6
### Base Algorithm Classes
7
8
Abstract base classes that define the core functionality shared by all reinforcement learning algorithms, including training loops, model management, and prediction interfaces.
9
10
```python { .api }
11
class BaseAlgorithm:
12
"""
13
Abstract base class for all RL algorithms.
14
15
Args:
16
policy: Policy class or string identifier
17
env: Environment or environment ID
18
learning_rate: Learning rate for optimization
19
policy_kwargs: Additional arguments for policy construction
20
stats_window_size: Window size for rollout logging averaging
21
tensorboard_log: Path to TensorBoard log directory
22
verbose: Verbosity level (0: no output, 1: info, 2: debug)
23
device: PyTorch device placement ("auto", "cpu", "cuda")
24
support_multi_env: Whether algorithm supports multiple environments
25
monitor_wrapper: Whether to wrap environment with Monitor
26
seed: Random seed for reproducibility
27
use_sde: Whether to use State Dependent Exploration
28
sde_sample_freq: Sample frequency for SDE
29
supported_action_spaces: List of supported action spaces
30
"""
31
def __init__(
32
self,
33
policy: Union[str, Type[BasePolicy]],
34
env: Union[GymEnv, str],
35
learning_rate: Union[float, Schedule],
36
policy_kwargs: Optional[Dict[str, Any]] = None,
37
stats_window_size: int = 100,
38
tensorboard_log: Optional[str] = None,
39
verbose: int = 0,
40
device: Union[torch.device, str] = "auto",
41
support_multi_env: bool = False,
42
monitor_wrapper: bool = True,
43
seed: Optional[int] = None,
44
use_sde: bool = False,
45
sde_sample_freq: int = -1,
46
supported_action_spaces: Optional[Tuple[Type[gym.Space], ...]] = None,
47
): ...
48
49
def learn(
50
self,
51
total_timesteps: int,
52
callback: MaybeCallback = None,
53
log_interval: int = 4,
54
tb_log_name: str = "run",
55
reset_num_timesteps: bool = True,
56
progress_bar: bool = False,
57
) -> "BaseAlgorithm":
58
"""
59
Train the agent for total_timesteps.
60
61
Args:
62
total_timesteps: Total number of timesteps to train
63
callback: Callback(s) called during training
64
log_interval: Log interval for training metrics
65
tb_log_name: TensorBoard log name
66
reset_num_timesteps: Reset timestep counter
67
progress_bar: Display progress bar
68
69
Returns:
70
Trained algorithm instance
71
"""
72
73
def predict(
74
self,
75
observation: Union[np.ndarray, Dict[str, np.ndarray]],
76
state: Optional[Tuple[np.ndarray, ...]] = None,
77
episode_start: Optional[np.ndarray] = None,
78
deterministic: bool = False,
79
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
80
"""
81
Get action from observation.
82
83
Args:
84
observation: Input observation
85
state: Hidden state for recurrent policies
86
episode_start: Mask marking which environments are at the start of an episode (used to reset recurrent hidden states)
87
deterministic: Use deterministic actions
88
89
Returns:
90
Tuple of (action, next_state)
91
"""
92
93
def save(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None:
94
"""Save model to file path."""
95
96
@classmethod
97
def load(
98
cls,
99
path: Union[str, pathlib.Path, io.BufferedIOBase],
100
env: Optional[GymEnv] = None,
101
device: Union[torch.device, str] = "auto",
102
custom_objects: Optional[Dict[str, Any]] = None,
103
print_system_info: bool = False,
104
force_reset: bool = True,
105
**kwargs,
106
) -> "BaseAlgorithm":
107
"""Load model from file path."""
108
109
def set_env(self, env: GymEnv, force_reset: bool = True) -> None:
110
"""Set new environment for the algorithm."""
111
112
def get_env(self) -> Optional[VecEnv]:
113
"""Get current environment."""
114
115
def set_random_seed(self, seed: Optional[int] = None) -> None:
116
"""Set random seed for reproducibility."""
117
118
class OnPolicyAlgorithm(BaseAlgorithm):
119
"""
120
Base class for on-policy algorithms (A2C, PPO).
121
122
Additional Args:
123
n_steps: Number of steps per environment per update
124
gamma: Discount factor
125
gae_lambda: GAE lambda parameter
126
ent_coef: Entropy coefficient
127
vf_coef: Value function coefficient
128
max_grad_norm: Maximum gradient norm for clipping
129
"""
130
def __init__(
131
self,
132
policy: Union[str, Type[ActorCriticPolicy]],
133
env: Union[GymEnv, str],
134
learning_rate: Union[float, Schedule],
135
n_steps: int,
136
gamma: float,
137
gae_lambda: float,
138
ent_coef: float,
139
vf_coef: float,
140
max_grad_norm: float,
141
use_sde: bool = False,
142
sde_sample_freq: int = -1,
143
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
144
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
145
**kwargs,
146
): ...
147
148
def collect_rollouts(
149
self,
150
env: VecEnv,
151
callback: BaseCallback,
152
rollout_buffer: RolloutBuffer,
153
n_rollout_steps: int,
154
) -> bool:
155
"""Collect rollout data from environment."""
156
157
class OffPolicyAlgorithm(BaseAlgorithm):
158
"""
159
Base class for off-policy algorithms (SAC, TD3, DDPG, DQN).
160
161
Additional Args:
162
buffer_size: Replay buffer size
163
learning_starts: Steps before learning starts
164
batch_size: Minibatch size for training
165
tau: Soft update coefficient for target networks
166
train_freq: Training frequency
167
gradient_steps: Gradient steps per training
168
action_noise: Action noise for exploration
169
replay_buffer_class: Replay buffer class
170
replay_buffer_kwargs: Additional buffer arguments
171
optimize_memory_usage: Enable memory optimizations
172
"""
173
def __init__(
174
self,
175
policy: Union[str, Type[BasePolicy]],
176
env: Union[GymEnv, str],
177
learning_rate: Union[float, Schedule],
178
buffer_size: int = 1_000_000,
179
learning_starts: int = 100,
180
batch_size: int = 256,
181
tau: float = 0.005,
182
gamma: float = 0.99,
183
train_freq: Union[int, Tuple[int, str]] = (1, "step"),
184
gradient_steps: int = 1,
185
action_noise: Optional[ActionNoise] = None,
186
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
187
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
188
optimize_memory_usage: bool = False,
189
**kwargs,
190
): ...
191
192
def _sample_action(
193
self,
194
learning_starts: int,
195
action_noise: Optional[ActionNoise] = None,
196
n_envs: int = 1,
197
) -> Tuple[np.ndarray, np.ndarray]:
198
"""Sample action with exploration noise."""
199
```
200
201
### Policy Base Classes
202
203
Neural network architectures that define how observations are processed and actions are selected, supporting different observation spaces and algorithm requirements.
204
205
```python { .api }
206
class BaseModel(torch.nn.Module):
207
"""
208
Base class for all neural network models.
209
210
Args:
211
observation_space: Observation space
212
action_space: Action space
213
lr_schedule: Learning rate schedule
214
use_sde: Whether to use State Dependent Exploration
215
log_std_init: Initial log standard deviation
216
full_std: Use (n_features x n_actions) parameters for the std instead of only (n_features,) when using SDE
217
sde_net_arch: Network architecture for SDE
218
use_expln: Use expln() instead of exp() to keep the standard deviation positive and prevent it from growing too fast
219
squash_output: Squash output with tanh
220
features_extractor_class: Feature extractor class
221
features_extractor_kwargs: Feature extractor arguments
222
share_features_extractor: Share feature extractor between actor/critic
223
normalize_images: Normalize image observations
224
optimizer_class: Optimizer class
225
optimizer_kwargs: Optimizer arguments
226
"""
227
def __init__(
228
self,
229
observation_space: gym.spaces.Space,
230
action_space: gym.spaces.Space,
231
lr_schedule: Schedule,
232
use_sde: bool = False,
233
log_std_init: float = 0.0,
234
full_std: bool = True,
235
sde_net_arch: Optional[List[int]] = None,
236
use_expln: bool = False,
237
squash_output: bool = False,
238
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
239
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
240
share_features_extractor: bool = True,
241
normalize_images: bool = True,
242
optimizer_class: Type[torch.optim.Optimizer] = torch.optim.Adam,
243
optimizer_kwargs: Optional[Dict[str, Any]] = None,
244
): ...
245
246
def forward(self, *args, **kwargs) -> torch.Tensor:
247
"""Forward pass through the network."""
248
249
class BasePolicy(BaseModel):
250
"""
251
Base policy class for all algorithms.
252
253
Args:
254
observation_space: Observation space
255
action_space: Action space
256
lr_schedule: Learning rate schedule
257
use_sde: Whether to use State Dependent Exploration
258
**kwargs: Additional arguments passed to BaseModel
259
"""
260
def __init__(
261
self,
262
observation_space: gym.spaces.Space,
263
action_space: gym.spaces.Space,
264
lr_schedule: Schedule,
265
use_sde: bool = False,
266
**kwargs,
267
): ...
268
269
def forward(self, obs: torch.Tensor, deterministic: bool = False) -> torch.Tensor:
270
"""Get action from observation."""
271
272
def predict(
273
self,
274
observation: Union[np.ndarray, Dict[str, np.ndarray]],
275
state: Optional[Tuple[np.ndarray, ...]] = None,
276
episode_start: Optional[np.ndarray] = None,
277
deterministic: bool = False,
278
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
279
"""Get action and state from observation."""
280
281
def _predict(
282
self, observation: torch.Tensor, deterministic: bool = False
283
) -> torch.Tensor:
284
"""Internal prediction method."""
285
286
class ActorCriticPolicy(BasePolicy):
287
"""
288
Policy with both actor and critic networks for on-policy algorithms.
289
290
Args:
291
observation_space: Observation space
292
action_space: Action space
293
lr_schedule: Learning rate schedule
294
net_arch: Network architecture specification
295
activation_fn: Activation function
296
ortho_init: Use orthogonal initialization
297
use_sde: Whether to use State Dependent Exploration
298
log_std_init: Initial log standard deviation
299
full_std: Use (n_features x n_actions) parameters for the std instead of only (n_features,) when using SDE
300
sde_net_arch: Network architecture for SDE
301
use_expln: Use expln() instead of exp() to keep the standard deviation positive and prevent it from growing too fast
302
squash_output: Squash output with tanh
303
features_extractor_class: Feature extractor class
304
features_extractor_kwargs: Feature extractor arguments
305
share_features_extractor: Share feature extractor between actor/critic
306
normalize_images: Normalize image observations
307
optimizer_class: Optimizer class
308
optimizer_kwargs: Optimizer arguments
309
"""
310
def __init__(
311
self,
312
observation_space: gym.spaces.Space,
313
action_space: gym.spaces.Space,
314
lr_schedule: Schedule,
315
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
316
activation_fn: Type[torch.nn.Module] = torch.nn.Tanh,
317
ortho_init: bool = True,
318
use_sde: bool = False,
319
log_std_init: float = 0.0,
320
full_std: bool = True,
321
sde_net_arch: Optional[List[int]] = None,
322
use_expln: bool = False,
323
squash_output: bool = False,
324
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
325
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
326
share_features_extractor: bool = True,
327
normalize_images: bool = True,
328
optimizer_class: Type[torch.optim.Optimizer] = torch.optim.Adam,
329
optimizer_kwargs: Optional[Dict[str, Any]] = None,
330
): ...
331
332
def forward(self, obs: torch.Tensor, deterministic: bool = False) -> torch.Tensor:
333
"""Forward pass through actor network."""
334
335
def evaluate_actions(
336
self, obs: torch.Tensor, actions: torch.Tensor
337
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
338
"""Evaluate actions for training."""
339
340
def get_distribution(self, obs: torch.Tensor) -> Distribution:
341
"""Get action distribution from observation."""
342
343
def predict_values(self, obs: torch.Tensor) -> torch.Tensor:
344
"""Predict state values using critic network."""
345
```
346
347
### Common Policy Aliases
348
349
All algorithms provide convenient aliases for their policy classes to simplify usage:
350
351
```python { .api }
352
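# Policy aliases exposed by each algorithm module. The targets shown are the on-policy (A2C, PPO) ones; off-policy algorithms such as SAC alias their own policy classes under the same names.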
353
MlpPolicy = ActorCriticPolicy # For algorithms like A2C, PPO
354
CnnPolicy = ActorCriticCnnPolicy # For image-based observations
355
MultiInputPolicy = MultiInputActorCriticPolicy # For dict observations
356
357
# Import examples:
358
from stable_baselines3.ppo import MlpPolicy as PPOMlpPolicy
359
from stable_baselines3.a2c import CnnPolicy as A2CCnnPolicy
360
from stable_baselines3.sac import MlpPolicy as SACMlpPolicy
361
```
362
363
### Experience Buffers
364
365
Storage mechanisms for training data that enable different sampling strategies and memory management approaches for various algorithm types.
366
367
```python { .api }
368
class BaseBuffer:
369
"""
370
Abstract base class for all experience buffers.
371
372
Args:
373
buffer_size: Maximum buffer capacity
374
observation_space: Observation space
375
action_space: Action space
376
device: PyTorch device placement
377
n_envs: Number of parallel environments
378
"""
379
def __init__(
380
self,
381
buffer_size: int,
382
observation_space: gym.spaces.Space,
383
action_space: gym.spaces.Space,
384
device: Union[torch.device, str] = "auto",
385
n_envs: int = 1,
386
): ...
387
388
def add(self, *args, **kwargs) -> None:
389
"""Add experience to buffer."""
390
391
def get(self, *args, **kwargs) -> Any:
392
"""Sample experience from buffer."""
393
394
def reset(self) -> None:
395
"""Reset buffer to empty state."""
396
397
def size(self) -> int:
398
"""Current buffer size."""
399
400
class RolloutBuffer(BaseBuffer):
401
"""
402
Buffer for on-policy algorithms that stores rollout trajectories.
403
404
Args:
405
buffer_size: Number of steps stored per environment (typically n_steps); total stored transitions is buffer_size * n_envs
406
observation_space: Observation space
407
action_space: Action space
408
device: PyTorch device placement
409
gae_lambda: GAE lambda parameter
410
gamma: Discount factor
411
n_envs: Number of parallel environments
412
"""
413
def __init__(
414
self,
415
buffer_size: int,
416
observation_space: gym.spaces.Space,
417
action_space: gym.spaces.Space,
418
device: Union[torch.device, str] = "auto",
419
gae_lambda: float = 1,
420
gamma: float = 0.99,
421
n_envs: int = 1,
422
): ...
423
424
def add(
425
self,
426
obs: np.ndarray,
427
actions: np.ndarray,
428
rewards: np.ndarray,
429
episode_starts: np.ndarray,
430
values: torch.Tensor,
431
log_probs: torch.Tensor,
432
) -> None:
433
"""
434
Add rollout data to buffer.
435
436
Args:
437
obs: Observations
438
actions: Actions taken
439
rewards: Rewards received
440
episode_starts: Episode start flags
441
values: State value estimates
442
log_probs: Action log probabilities
443
"""
444
445
def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSamples, None, None]:
446
"""
447
Sample batches from buffer.
448
449
Args:
450
batch_size: Size of batches to sample
451
452
Yields:
453
Batches of rollout data
454
"""
455
456
def compute_returns_and_advantage(
457
self, last_values: torch.Tensor, dones: np.ndarray
458
) -> None:
459
"""
460
Compute returns and advantages using GAE.
461
462
Args:
463
last_values: Value estimates for final states
464
dones: Episode termination flags
465
"""
466
467
class ReplayBuffer(BaseBuffer):
468
"""
469
Experience replay buffer for off-policy algorithms.
470
471
Args:
472
buffer_size: Maximum buffer capacity
473
observation_space: Observation space
474
action_space: Action space
475
device: PyTorch device placement
476
n_envs: Number of parallel environments
477
optimize_memory_usage: Enable memory optimizations
478
handle_timeout_termination: Handle timeout terminations properly
479
"""
480
def __init__(
481
self,
482
buffer_size: int,
483
observation_space: gym.spaces.Space,
484
action_space: gym.spaces.Space,
485
device: Union[torch.device, str] = "auto",
486
n_envs: int = 1,
487
optimize_memory_usage: bool = False,
488
handle_timeout_termination: bool = True,
489
): ...
490
491
def add(
492
self,
493
obs: np.ndarray,
494
next_obs: np.ndarray,
495
actions: np.ndarray,
496
rewards: np.ndarray,
497
dones: np.ndarray,
498
infos: List[Dict[str, Any]],
499
) -> None:
500
"""
501
Add transition to replay buffer.
502
503
Args:
504
obs: Current observations
505
next_obs: Next observations
506
actions: Actions taken
507
rewards: Rewards received
508
dones: Episode termination flags
509
infos: Additional information
510
"""
511
512
def sample(self, batch_size: int, env: Optional[VecEnv] = None) -> ReplayBufferSamples:
513
"""
514
Sample batch of transitions.
515
516
Args:
517
batch_size: Number of transitions to sample
518
env: Environment for normalization
519
520
Returns:
521
Batch of experience samples
522
"""
523
524
class DictRolloutBuffer(RolloutBuffer):
525
"""Rollout buffer for dictionary observations."""
526
527
class DictReplayBuffer(ReplayBuffer):
528
"""Replay buffer for dictionary observations."""
529
```
530
531
## Usage Examples
532
533
### Custom Policy Architecture
534
535
```python
536
from stable_baselines3 import PPO
537
from stable_baselines3.common.policies import ActorCriticPolicy
538
import torch.nn as nn
539
540
# Define custom network architecture
541
policy_kwargs = dict(
542
net_arch=dict(pi=[128, 128], vf=[128, 128]),
543
activation_fn=nn.ReLU,
544
ortho_init=True,
545
)
546
547
model = PPO(
548
"MlpPolicy",
549
env,
550
policy_kwargs=policy_kwargs,
551
verbose=1
552
)
553
```
554
555
### Custom Buffer Configuration
556
557
```python
558
from stable_baselines3 import SAC
559
from stable_baselines3.common.buffers import ReplayBuffer
560
561
# Custom replay buffer settings
562
replay_buffer_kwargs = dict(
563
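handle_timeout_termination=False,  # must be False when optimize_memory_usage=True
564
# optimize_memory_usage is a SAC constructor argument (set below), not a replay_buffer_kwargs key
565
)
566
567
model = SAC(
568
"MlpPolicy",
569
env,
570
buffer_size=500000,
optimize_memory_usage=True,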
571
replay_buffer_kwargs=replay_buffer_kwargs,
572
verbose=1
573
)
574
```
575
576
### Accessing Buffer Data
577
578
```python
579
# Access replay buffer after training
580
replay_buffer = model.replay_buffer
581
582
# Sample transitions for analysis
583
batch = replay_buffer.sample(batch_size=256)
584
observations = batch.observations
585
actions = batch.actions
586
rewards = batch.rewards
587
```
588
589
## Types
590
591
```python { .api }
592
from typing import Union, Optional, Type, Callable, Dict, Any, List, Tuple, Generator
593
import numpy as np
594
import torch
595
import gymnasium as gym
596
import pathlib
597
import io
598
from stable_baselines3.common.type_aliases import GymEnv, Schedule, MaybeCallback, PyTorchObs, TensorDict, ReplayBufferSamples, RolloutBufferSamples
599
from stable_baselines3.common.policies import BasePolicy, ActorCriticPolicy, ActorCriticCnnPolicy, MultiInputActorCriticPolicy
600
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor
601
from stable_baselines3.common.noise import ActionNoise
602
from stable_baselines3.common.vec_env import VecEnv
603
```