# Stable Baselines3

Stable Baselines3 is a Python library providing reliable PyTorch implementations of state-of-the-art reinforcement learning algorithms. It offers a unified, scikit-learn-like interface (`learn`, `predict`, `save`, `load`) across algorithms, with extensive customization options, and is designed to support both research and practical deployment of deep reinforcement learning.

## Package Information

- **Package Name**: stable-baselines3
- **Language**: Python
- **Installation**: `pip install stable-baselines3`

## Core Imports

```python
from stable_baselines3 import PPO, A2C, SAC, TD3, DDPG, DQN
```

Common utilities and components:

```python
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.utils import set_random_seed, get_system_info
```

## Basic Usage

```python
import gymnasium as gym

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv

# Create vectorized training and evaluation environments.
# Each list entry is a factory, so every VecEnv builds its own env instance.
env = DummyVecEnv([lambda: gym.make("CartPole-v1")])
# Monitor records episode rewards/lengths, which EvalCallback uses for reporting
eval_env = DummyVecEnv([lambda: Monitor(gym.make("CartPole-v1"))])

# Create PPO agent
model = PPO("MlpPolicy", env, verbose=1)

# Evaluate every 10,000 steps and keep the best model in ./logs/
eval_callback = EvalCallback(
    eval_env,
    best_model_save_path="./logs/",
    log_path="./logs/",
    eval_freq=10000,
    deterministic=True,
    render=False,
)

# Train the agent
model.learn(total_timesteps=100000, callback=eval_callback)

# Save the trained model
model.save("ppo_cartpole")

# Load and use the trained model
model = PPO.load("ppo_cartpole")

# Test the trained agent (VecEnv API: reset() returns only obs, step() returns 4 values)
obs = env.reset()
for _ in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, rewards, dones, infos = env.step(action)
    # VecEnvs reset finished sub-environments automatically,
    # so no manual env.reset() is needed when an episode ends.
```

## Architecture

Stable Baselines3 follows a hierarchical architecture that promotes code reuse and extensibility:

- **Algorithm Classes**: Implementations of specific RL algorithms (PPO, SAC, etc.)
- **Base Classes**: Abstract base classes providing common functionality
  - `BaseAlgorithm`: Core training loop and model management
  - `OnPolicyAlgorithm`: Subclass of `BaseAlgorithm` for on-policy algorithms such as A2C and PPO
  - `OffPolicyAlgorithm`: Subclass of `BaseAlgorithm` for off-policy algorithms such as SAC, TD3, DDPG, and DQN
- **Policies**: Neural network architectures for different observation spaces
- **Buffers**: Experience storage for training (rollout buffers, replay buffers)
- **Common Components**: Utilities, callbacks, environment wrappers, and evaluation tools

This design enables consistent interfaces across algorithms while allowing algorithm-specific optimizations.
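
To make the split between shared and algorithm-specific code concrete, here is a toy sketch of the template-method pattern the base classes follow. The classes `ToyBaseAlgorithm`, `ToyOnPolicy`, and `ToyOffPolicy` and their `collect`/`train` hooks are hypothetical simplifications, not SB3's actual classes or method names: the shared `learn` loop lives in the base class, while each subclass decides how experience is gathered and consumed.

```python
from abc import ABC, abstractmethod
from typing import List


class ToyBaseAlgorithm(ABC):
    """Hypothetical stand-in for BaseAlgorithm: owns the shared training loop."""

    def __init__(self) -> None:
        self.num_timesteps = 0
        self.log: List[str] = []

    def learn(self, total_timesteps: int) -> "ToyBaseAlgorithm":
        # Shared loop: alternate data collection and network updates
        while self.num_timesteps < total_timesteps:
            self.num_timesteps += self.collect()
            self.train()
        return self

    @abstractmethod
    def collect(self) -> int:
        """Gather experience and return the number of environment steps taken."""

    @abstractmethod
    def train(self) -> None:
        """Update the networks from the collected experience."""


class ToyOnPolicy(ToyBaseAlgorithm):
    """Mimics the on-policy pattern: fill a rollout of n_steps, train on it, discard it."""

    def __init__(self, n_steps: int) -> None:
        super().__init__()
        self.n_steps = n_steps

    def collect(self) -> int:
        self.log.append(f"rollout({self.n_steps})")
        return self.n_steps

    def train(self) -> None:
        self.log.append("train-on-rollout")


class ToyOffPolicy(ToyBaseAlgorithm):
    """Mimics the off-policy pattern: add train_freq steps to a persistent replay buffer."""

    def __init__(self, train_freq: int) -> None:
        super().__init__()
        self.train_freq = train_freq
        self.buffer_size = 0

    def collect(self) -> int:
        self.buffer_size += self.train_freq
        self.log.append(f"step({self.train_freq})")
        return self.train_freq

    def train(self) -> None:
        self.log.append(f"train-on-replay({self.buffer_size})")
```

For example, `ToyOnPolicy(n_steps=2048).learn(4096)` collects and trains on two fresh rollouts, whereas `ToyOffPolicy(train_freq=1)` keeps growing one buffer and trains after every step.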

## Capabilities

### Core Algorithms

Implementation of six state-of-the-art deep reinforcement learning algorithms with consistent interfaces and extensive configuration options.

```python { .api }
class PPO(OnPolicyAlgorithm):
    """Proximal Policy Optimization algorithm."""
    def __init__(self, policy, env, learning_rate=3e-4, n_steps=2048, batch_size=64,
                 n_epochs=10, gamma=0.99, gae_lambda=0.95, clip_range=0.2, **kwargs): ...

class A2C(OnPolicyAlgorithm):
    """Advantage Actor-Critic algorithm."""
    def __init__(self, policy, env, learning_rate=7e-4, n_steps=5, gamma=0.99,
                 gae_lambda=1.0, ent_coef=0.0, vf_coef=0.5, **kwargs): ...

class SAC(OffPolicyAlgorithm):
    """Soft Actor-Critic algorithm."""
    def __init__(self, policy, env, learning_rate=3e-4, buffer_size=1000000,
                 batch_size=256, tau=0.005, gamma=0.99, **kwargs): ...

class TD3(OffPolicyAlgorithm):
    """Twin Delayed Deep Deterministic Policy Gradient algorithm."""
    def __init__(self, policy, env, learning_rate=1e-3, buffer_size=1000000,
                 batch_size=100, tau=0.005, gamma=0.99, **kwargs): ...

class DDPG(TD3):
    """Deep Deterministic Policy Gradient algorithm."""
    def __init__(self, policy, env, **kwargs): ...

class DQN(OffPolicyAlgorithm):
    """Deep Q-Network algorithm."""
    def __init__(self, policy, env, learning_rate=1e-4, buffer_size=1000000,
                 batch_size=32, tau=1.0, gamma=0.99, **kwargs): ...
```

[Core Algorithms](./algorithms.md)
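
The `clip_range` argument of `PPO` bounds how far the updated policy may move from the policy that collected the rollout. A minimal NumPy sketch of the clipped surrogate loss from the PPO paper, written independently of SB3's internals (the function name `ppo_clip_loss` is hypothetical):

```python
import numpy as np


def ppo_clip_loss(log_prob: np.ndarray, old_log_prob: np.ndarray,
                  advantages: np.ndarray, clip_range: float = 0.2) -> float:
    """Clipped surrogate policy loss (to be minimized)."""
    # Probability ratio pi_new(a|s) / pi_old(a|s), computed from log-probabilities
    ratio = np.exp(log_prob - old_log_prob)
    # Unclipped and clipped surrogate objectives
    surr_unclipped = advantages * ratio
    surr_clipped = advantages * np.clip(ratio, 1.0 - clip_range, 1.0 + clip_range)
    # Pessimistic bound: elementwise minimum, averaged and negated into a loss
    return float(-np.mean(np.minimum(surr_unclipped, surr_clipped)))
```

Because the minimum of the two terms is taken, clipping can only remove the incentive to push the ratio further outside `[1 - clip_range, 1 + clip_range]`; it never rewards doing so.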

### Common Framework

Base classes, policies, and buffers that provide the foundation for all algorithms and enable consistent behavior across the library.

```python { .api }
class BaseAlgorithm:
    """Abstract base class for all RL algorithms."""
    def learn(self, total_timesteps, callback=None, log_interval=4,
              tb_log_name="run", reset_num_timesteps=True, progress_bar=False): ...
    def predict(self, observation, state=None, episode_start=None, deterministic=False): ...
    def save(self, path): ...
    @classmethod
    def load(cls, path, env=None, device="auto", **kwargs): ...

class BasePolicy:
    """Base policy class for all neural network policies."""
    def forward(self, obs, deterministic=False): ...
    def predict(self, observation, state=None, episode_start=None, deterministic=False): ...

class RolloutBuffer:
    """Buffer for on-policy algorithms."""
    def add(self, obs, actions, rewards, episode_starts, values, log_probs): ...
    def get(self, batch_size=None): ...

class ReplayBuffer:
    """Experience replay buffer for off-policy algorithms."""
    def add(self, obs, next_obs, actions, rewards, dones, infos): ...
    def sample(self, batch_size, env=None): ...
```

[Common Framework](./common-framework.md)
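
`RolloutBuffer` stores values alongside rewards and `episode_starts` so that advantages can be computed with Generalized Advantage Estimation, controlled by the `gamma` and `gae_lambda` arguments of the on-policy algorithms. The following single-environment NumPy sketch illustrates that computation; the function `compute_gae` and its argument layout are assumptions for illustration, not the buffer's API.

```python
import numpy as np


def compute_gae(rewards, values, episode_starts, last_value, last_done,
                gamma=0.99, gae_lambda=0.95):
    """Advantages and returns for a single-env rollout via GAE.

    episode_starts[t] is 1.0 if observation t is the first of a new episode.
    last_value / last_done describe the observation that follows the rollout.
    """
    rewards = np.asarray(rewards, dtype=np.float64)
    values = np.asarray(values, dtype=np.float64)
    episode_starts = np.asarray(episode_starts, dtype=np.float64)
    n_steps = len(rewards)
    advantages = np.zeros(n_steps)
    last_gae = 0.0
    for t in reversed(range(n_steps)):
        if t == n_steps - 1:
            next_non_terminal = 1.0 - float(last_done)
            next_value = last_value
        else:
            # If step t+1 starts a new episode, step t was the last of its episode
            next_non_terminal = 1.0 - episode_starts[t + 1]
            next_value = values[t + 1]
        # One-step TD error r_t + gamma * V(s_{t+1}) - V(s_t), no bootstrap at episode end
        delta = rewards[t] + gamma * next_value * next_non_terminal - values[t]
        # (gamma * lambda)-discounted sum of TD errors, reset at episode boundaries
        last_gae = delta + gamma * gae_lambda * next_non_terminal * last_gae
        advantages[t] = last_gae
    # Value-function targets
    returns = advantages + values
    return advantages, returns
```

Setting `gae_lambda=1.0` (A2C's default) turns the advantage into a Monte Carlo return minus the value baseline, while `gae_lambda=0.0` reduces it to the one-step TD error.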

### Vectorized Environments

Environment vectorization and wrappers for parallel training, normalization, monitoring, and other common preprocessing tasks.

```python { .api }
class DummyVecEnv:
    """Sequential vectorized environment (all environments run in the current process)."""
    def __init__(self, env_fns): ...
    def reset(self): ...
    def step(self, actions): ...

class SubprocVecEnv:
    """Multiprocessing vectorized environment (one process per environment)."""
    def __init__(self, env_fns, start_method=None): ...

class VecNormalize:
    """Normalize observations and rewards."""
    def __init__(self, venv, training=True, norm_obs=True, norm_reward=True,
                 clip_obs=10.0, clip_reward=10.0, gamma=0.99, epsilon=1e-8): ...

class VecFrameStack:
    """Stack the last n_stack observations (designed for image observations)."""
    def __init__(self, venv, n_stack, channels_order=None): ...
```

[Environments](./environments.md)
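
`VecNormalize` standardizes observations with statistics that are updated as training data arrives, then clips them to `[-clip_obs, clip_obs]`. Below is a minimal sketch of that idea for observations only, assuming a batched running mean/variance update; `RunningMeanStd` and `normalize_obs` here are illustrative re-implementations, and reward normalization is omitted.

```python
import numpy as np


class RunningMeanStd:
    """Running mean/variance, merged batch by batch with the parallel variance formula."""

    def __init__(self, shape=(), epsilon=1e-4):
        self.mean = np.zeros(shape, dtype=np.float64)
        self.var = np.ones(shape, dtype=np.float64)
        self.count = epsilon  # small pseudo-count avoids division by zero

    def update(self, batch: np.ndarray) -> None:
        batch_mean = batch.mean(axis=0)
        batch_var = batch.var(axis=0)
        batch_count = batch.shape[0]
        delta = batch_mean - self.mean
        total = self.count + batch_count
        # Combine two (mean, variance, count) summaries into one
        self.mean = self.mean + delta * batch_count / total
        m2 = (self.var * self.count + batch_var * batch_count
              + delta**2 * self.count * batch_count / total)
        self.var = m2 / total
        self.count = total


def normalize_obs(obs, rms, clip_obs=10.0, epsilon=1e-8):
    """Standardize observations with the running statistics, then clip them."""
    return np.clip((obs - rms.mean) / np.sqrt(rms.var + epsilon), -clip_obs, clip_obs)
```

Because the statistics keep changing while `training=True`, the same statistics must be reused (and frozen) when evaluating or deploying a model trained on normalized observations.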

### Training Utilities

Callbacks, noise generators, evaluation tools, and other utilities to enhance and monitor training processes.

```python { .api }
class EvalCallback:
    """Evaluate model during training."""
    def __init__(self, eval_env, callback_on_new_best=None, n_eval_episodes=5,
                 eval_freq=10000, log_path=None, best_model_save_path=None,
                 deterministic=True, render=False, verbose=1): ...

class CheckpointCallback:
    """Save model at regular intervals."""
    def __init__(self, save_freq, save_path, name_prefix="rl_model",
                 save_replay_buffer=False, save_vecnormalize=False, verbose=0): ...

class NormalActionNoise:
    """Gaussian action noise for exploration."""
    def __init__(self, mean, sigma): ...

def evaluate_policy(model, env, n_eval_episodes=10, deterministic=True,
                    render=False, callback=None, reward_threshold=None,
                    return_episode_rewards=False, warn=True): ...
```

[Training Utilities](./training-utilities.md)
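
`NormalActionNoise` is typically passed to off-policy algorithms with continuous actions (for example `TD3(..., action_noise=...)`) to add Gaussian exploration noise to otherwise deterministic actions. The sketch below reproduces that idea with NumPy; `GaussianNoise` and `explore` are hypothetical names, and keeping actions inside the bounds is shown as a plain `np.clip`.

```python
from typing import Optional

import numpy as np


class GaussianNoise:
    """Hypothetical stand-in for NormalActionNoise: one N(mean, sigma) draw per call."""

    def __init__(self, mean: np.ndarray, sigma: np.ndarray, seed: Optional[int] = None) -> None:
        self.mean = np.asarray(mean, dtype=np.float64)
        self.sigma = np.asarray(sigma, dtype=np.float64)
        self.rng = np.random.default_rng(seed)

    def __call__(self) -> np.ndarray:
        # One independent Gaussian sample per action dimension
        return self.rng.normal(self.mean, self.sigma)


def explore(action: np.ndarray, noise: GaussianNoise, low: float, high: float) -> np.ndarray:
    """Perturb a deterministic action and keep it inside the action-space bounds."""
    return np.clip(np.asarray(action) + noise(), low, high)
```

With SB3 itself, the equivalent noise object is usually built as `NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))`, where `n_actions` is the action-space dimension.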

### Hindsight Experience Replay (HER)

Implementation of Hindsight Experience Replay for goal-conditioned reinforcement learning, enabling learning from failed attempts by treating them as successful attempts toward different goals.

```python { .api }
class HerReplayBuffer:
    """Replay buffer with Hindsight Experience Replay."""
    def __init__(self, buffer_size, observation_space, action_space, device="auto",
                 n_envs=1, optimize_memory_usage=False, handle_timeout_termination=True,
                 n_sampled_goal=4, goal_selection_strategy="future", wrapped_env=None): ...

class GoalSelectionStrategy:
    """Enumeration of goal selection strategies."""
    FUTURE = "future"
    FINAL = "final"
    EPISODE = "episode"
    RANDOM = "random"
```

[HER](./her.md)
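
With the default `goal_selection_strategy="future"`, a stored transition is replayed with a goal the agent actually reached later in the same episode, and its reward is recomputed for that goal. Below is a NumPy sketch of this relabeling for one episode, assuming a hypothetical sparse reward (0 within `threshold` of the goal, -1 otherwise) and hypothetical helper names; here `n_sampled_goal` is simply the number of relabeled copies generated per transition.

```python
import numpy as np


def sparse_reward(achieved_goal, desired_goal, threshold=0.05):
    """Hypothetical goal-conditioned reward: 0 on success, -1 otherwise."""
    distance = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
    return -(distance > threshold).astype(np.float64)


def relabel_future(next_achieved_goals, n_sampled_goal, rng):
    """'future' strategy: pair each step t with goals achieved at step t or later.

    next_achieved_goals[t] is the goal achieved after taking the action at step t.
    Returns (transition indices, substituted goals, recomputed rewards).
    """
    episode_length = len(next_achieved_goals)
    transitions, goals = [], []
    for t in range(episode_length):
        # Sample steps j uniformly from [t, episode_length) and reuse what was achieved there
        future = rng.integers(t, episode_length, size=n_sampled_goal)
        for j in future:
            transitions.append(t)
            goals.append(next_achieved_goals[j])
    transitions = np.array(transitions)
    goals = np.array(goals)
    # Rewards must be recomputed because the goal changed
    rewards = sparse_reward(next_achieved_goals[transitions], goals)
    return transitions, goals, rewards
```

A relabeled transition whose substituted goal is exactly what that step achieved receives reward 0, so even an episode that never reached its original goal yields successful examples.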

### System Information

Utilities for retrieving system and environment information for debugging and reproducibility.

```python { .api }
def get_system_info(print_info: bool = True) -> tuple[dict[str, str], str]:
    """
    Retrieve system and Python environment information for the current system.

    Parameters:
    - print_info: Whether to print the information

    Returns:
    Tuple containing a dictionary with version information and a formatted string
    """
```

## Types

```python { .api }
from typing import Callable, Dict, List, Tuple, Union

import gymnasium as gym
import torch

from stable_baselines3.common.callbacks import BaseCallback
# Abstract base class of all vectorized envs (DummyVecEnv, SubprocVecEnv, wrappers, ...)
from stable_baselines3.common.vec_env import VecEnv

# Environment types: a single Gymnasium env or a vectorized env
GymEnv = Union[gym.Env, VecEnv]

# Schedule: maps progress_remaining (1.0 at the start, 0.0 at the end of training)
# to a hyperparameter value such as the learning rate
Schedule = Callable[[float], float]

# Callback argument accepted by learn()
MaybeCallback = Union[None, Callable, List[BaseCallback], BaseCallback]

# Algorithm-specific policy types
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
from stable_baselines3.sac.policies import SACPolicy
from stable_baselines3.td3.policies import TD3Policy
from stable_baselines3.dqn.policies import DQNPolicy

# Buffer types
from stable_baselines3.common.buffers import RolloutBuffer, ReplayBuffer
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer

# Noise types
from stable_baselines3.common.noise import ActionNoise

# Observation and action types
PyTorchObs = Union[torch.Tensor, Dict[str, torch.Tensor]]
TensorDict = Dict[str, torch.Tensor]

# Accepted forms of the `train_freq` argument, e.g. 4 or (1, "episode")
TrainFreq = Union[int, Tuple[int, str]]
```
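
The `Schedule` type is what lets hyperparameters such as `learning_rate` (and PPO's `clip_range`) vary over training: the schedule is called with `progress_remaining`, which decreases from 1.0 to 0.0. Below is a minimal linear schedule in the spirit of the one shown in the SB3 documentation; the `final_value` argument and the `progress_remaining` helper are illustrative additions, not library functions.

```python
from typing import Callable

Schedule = Callable[[float], float]


def linear_schedule(initial_value: float, final_value: float = 0.0) -> Schedule:
    """Interpolate linearly from initial_value (start) to final_value (end of training)."""

    def schedule(progress_remaining: float) -> float:
        # progress_remaining is 1.0 at the start of training and 0.0 at the end
        return final_value + progress_remaining * (initial_value - final_value)

    return schedule


def progress_remaining(num_timesteps: int, total_timesteps: int) -> float:
    """Hypothetical helper: fraction of the training budget still left."""
    return 1.0 - num_timesteps / total_timesteps
```

Passing `learning_rate=linear_schedule(3e-4)` to `PPO` would then decay the learning rate linearly to zero over `total_timesteps`.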