# Core Algorithms

Implementations of six deep reinforcement learning algorithms that share a consistent interface and expose extensive configuration options. Each algorithm suits particular action spaces and learning scenarios.

## Capabilities

### Proximal Policy Optimization (PPO)

On-policy algorithm that optimizes a clipped surrogate objective. The clipping keeps each policy update close to the policy that collected the data, which stabilizes training. It supports both continuous and discrete action spaces and is generally robust to hyperparameter choices. Because it is on-policy, it is usually less sample-efficient than off-policy algorithms such as SAC or TD3.

```python { .api }
class PPO(OnPolicyAlgorithm):
    """
    Proximal Policy Optimization algorithm.

    Args:
        policy: Policy class or string ("MlpPolicy", "CnnPolicy", "MultiInputPolicy")
        env: Environment or environment ID
        learning_rate: Learning rate; can be a function of the remaining progress (from 1 to 0)
        n_steps: Number of steps to run for each environment per update
            (the rollout buffer holds n_steps * n_envs transitions)
        batch_size: Minibatch size
        n_epochs: Number of epochs when optimizing the surrogate loss
        gamma: Discount factor
        gae_lambda: Bias vs. variance trade-off factor for Generalized Advantage Estimation (GAE)
        clip_range: Clipping parameter for the PPO surrogate objective; can be a schedule
        clip_range_vf: Clipping parameter for the value function (None: no value clipping)
        normalize_advantage: Whether to normalize advantages
        ent_coef: Entropy coefficient in the loss (higher values encourage exploration)
        vf_coef: Value function coefficient in the loss
        max_grad_norm: Maximum gradient norm for gradient clipping
        use_sde: Whether to use generalized State Dependent Exploration (gSDE)
        sde_sample_freq: Sample a new gSDE noise matrix every n steps (-1: only at the start of each rollout)
        rollout_buffer_class: Rollout buffer class to use (None: chosen automatically)
        rollout_buffer_kwargs: Keyword arguments for rollout buffer creation
        target_kl: Limit on the KL divergence between updates; the update stops early
            when the approximate KL divergence grows too large (None: no limit)
        stats_window_size: Number of episodes averaged for rollout logging
        tensorboard_log: Path to TensorBoard log directory
        policy_kwargs: Additional arguments for policy construction
        verbose: Verbosity level (0: no output, 1: info, 2: debug)
        seed: Seed for the pseudo-random number generators
        device: PyTorch device placement ("auto", "cpu", "cuda")
        _init_setup_model: Whether to build the network at creation
    """
    def __init__(
        self,
        policy: Union[str, Type[ActorCriticPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 3e-4,
        n_steps: int = 2048,
        batch_size: int = 64,
        n_epochs: int = 10,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: Union[float, Schedule] = 0.2,
        clip_range_vf: Optional[Union[float, Schedule]] = None,
        normalize_advantage: bool = True,
        ent_coef: float = 0.0,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
        rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
        target_kl: Optional[float] = None,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[torch.device, str] = "auto",
        _init_setup_model: bool = True,
    ): ...
```
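Per sample, the clipped surrogate objective is min(r·A, clip(r, 1 − ε, 1 + ε)·A), where r is the probability ratio between the new and old policy, A is the advantage, and ε is `clip_range`. The plain-Python sketch below computes the loss that gets minimized, which is the negative mean of that objective. It also includes an advantage-normalization helper similar in spirit to `normalize_advantage=True`. This is an illustration under those assumptions, not the library's implementation. The function names `normalize` and `ppo_clip_loss` are hypothetical.

```python
import math
from typing import List


def normalize(advantages: List[float], eps: float = 1e-8) -> List[float]:
    """Shift advantages to zero mean and scale them to unit standard deviation."""
    n = len(advantages)
    mean = sum(advantages) / n
    std = math.sqrt(sum((a - mean) ** 2 for a in advantages) / n)
    return [(a - mean) / (std + eps) for a in advantages]


def ppo_clip_loss(
    log_probs_new: List[float],
    log_probs_old: List[float],
    advantages: List[float],
    clip_range: float = 0.2,
) -> float:
    """Negative clipped surrogate objective, averaged over a minibatch."""
    total = 0.0
    for new, old, adv in zip(log_probs_new, log_probs_old, advantages):
        ratio = math.exp(new - old)  # pi_new(a|s) / pi_old(a|s)
        clipped = min(max(ratio, 1.0 - clip_range), 1.0 + clip_range)
        # Pessimistic bound: keep the smaller of the unclipped and clipped terms
        total += min(ratio * adv, clipped * adv)
    return -total / len(advantages)
```

For example, with `clip_range=0.2` a ratio of 1.5 on a positive advantage counts only as 1.2. The gradient therefore gives no incentive to push the policy further in that direction.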

### Advantage Actor-Critic (A2C)

On-policy actor-critic algorithm. The actor (policy) takes a policy-gradient step weighted by advantage estimates from the critic (value function). A2C is the synchronous, deterministic variant of A3C. It is simpler to implement and in practice often performs as well as or better than the asynchronous version.

```python { .api }
class A2C(OnPolicyAlgorithm):
    """
    Advantage Actor-Critic algorithm.

    Args:
        policy: Policy class or string ("MlpPolicy", "CnnPolicy", "MultiInputPolicy")
        env: Environment or environment ID
        learning_rate: Learning rate; can be a function of the remaining progress (from 1 to 0)
        n_steps: Number of steps to run for each environment per update
        gamma: Discount factor
        gae_lambda: Bias vs. variance trade-off factor for GAE (1.0: classic advantage estimate)
        ent_coef: Entropy coefficient in the loss (higher values encourage exploration)
        vf_coef: Value function coefficient in the loss
        max_grad_norm: Maximum gradient norm for gradient clipping
        rms_prop_eps: RMSprop optimizer epsilon
        use_rms_prop: Whether to use RMSprop (default) instead of Adam as the optimizer
        use_sde: Whether to use generalized State Dependent Exploration (gSDE)
        sde_sample_freq: Sample a new gSDE noise matrix every n steps (-1: only at the start of each rollout)
        rollout_buffer_class: Rollout buffer class to use (None: chosen automatically)
        rollout_buffer_kwargs: Keyword arguments for rollout buffer creation
        normalize_advantage: Whether to normalize advantages
        stats_window_size: Number of episodes averaged for rollout logging
        tensorboard_log: Path to TensorBoard log directory
        policy_kwargs: Additional arguments for policy construction
        verbose: Verbosity level (0: no output, 1: info, 2: debug)
        seed: Seed for the pseudo-random number generators
        device: PyTorch device placement ("auto", "cpu", "cuda")
        _init_setup_model: Whether to build the network at creation
    """
    def __init__(
        self,
        policy: Union[str, Type[ActorCriticPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 7e-4,
        n_steps: int = 5,
        gamma: float = 0.99,
        gae_lambda: float = 1.0,
        ent_coef: float = 0.0,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        rms_prop_eps: float = 1e-5,
        use_rms_prop: bool = True,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
        rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
        normalize_advantage: bool = False,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[torch.device, str] = "auto",
        _init_setup_model: bool = True,
    ): ...
```
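A2C and PPO both compute advantages with Generalized Advantage Estimation (GAE). The formulas are δ_t = r_t + γ·V(s_{t+1})·(1 − done_t) − V(s_t) and A_t = δ_t + γ·λ·(1 − done_t)·A_{t+1}, where λ is `gae_lambda`. The sketch below applies them over one rollout from a single environment. With A2C's default `gae_lambda=1.0`, this reduces to bootstrapped n-step returns minus the value baseline. The function name is hypothetical. The done-flag convention is also an assumption: here `dones[t]` marks that the episode ended at step t.

```python
from typing import List, Tuple


def compute_gae(
    rewards: List[float],
    values: List[float],
    dones: List[bool],
    last_value: float,
    gamma: float = 0.99,
    gae_lambda: float = 1.0,
) -> Tuple[List[float], List[float]]:
    """Return (advantages, returns) for one rollout.

    dones[t] is True if the episode terminated at step t, in which case
    nothing is bootstrapped from the following state.
    """
    n = len(rewards)
    advantages = [0.0] * n
    gae = 0.0
    for t in reversed(range(n)):
        next_value = last_value if t == n - 1 else values[t + 1]
        not_done = 0.0 if dones[t] else 1.0
        delta = rewards[t] + gamma * next_value * not_done - values[t]  # TD error
        gae = delta + gamma * gae_lambda * not_done * gae
        advantages[t] = gae
    # Value-function targets: advantage plus the baseline it was measured against
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns
```

Lowering `gae_lambda` toward 0 weights short-horizon TD errors more heavily. This trades variance for bias, which is why PPO's default is 0.95.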
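
### Soft Actor-Critic (SAC)

Off-policy actor-critic algorithm that maximizes expected return plus an entropy bonus on the policy, which encourages exploration. By default, the entropy coefficient is tuned automatically. SAC is designed for continuous control, and this implementation supports only continuous (`Box`) action spaces. It is typically both sample-efficient and stable.
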
```python { .api }
class SAC(OffPolicyAlgorithm):
    """
    Soft Actor-Critic algorithm.

    Args:
        policy: Policy class or string ("MlpPolicy", "CnnPolicy", "MultiInputPolicy")
        env: Environment or environment ID
        learning_rate: Learning rate; can be a function of the remaining progress (from 1 to 0)
        buffer_size: Size of the replay buffer
        learning_starts: Number of environment steps to collect before learning starts
        batch_size: Minibatch size for training
        tau: Soft update coefficient for target networks
        gamma: Discount factor
        train_freq: Update the model every train_freq steps; a tuple such as (1, "episode") changes the unit
        gradient_steps: Gradient steps per update (-1: as many as environment steps collected in the rollout)
        action_noise: Action noise for exploration (None by default; can help on hard-exploration problems)
        replay_buffer_class: Replay buffer class (None: chosen automatically)
        replay_buffer_kwargs: Additional replay buffer arguments
        optimize_memory_usage: Enable a memory-efficient variant of the replay buffer
        n_steps: Number of steps for n-step return calculation
        ent_coef: Entropy regularization coefficient ("auto" to learn it; "auto_0.1" also sets the initial value)
        target_update_interval: Update the target network every n gradient steps
        target_entropy: Target entropy when learning ent_coef ("auto": -dim(action_space))
        use_sde: Whether to use generalized State Dependent Exploration (gSDE)
        sde_sample_freq: Sample a new gSDE noise matrix every n steps (-1: only at the start of each rollout)
        use_sde_at_warmup: Use gSDE instead of uniform random sampling during the warm-up phase (before learning starts)
        stats_window_size: Number of episodes averaged for rollout logging
        tensorboard_log: Path to TensorBoard log directory
        policy_kwargs: Additional arguments for policy construction
        verbose: Verbosity level (0: no output, 1: info, 2: debug)
        seed: Seed for the pseudo-random number generators
        device: PyTorch device placement ("auto", "cpu", "cuda")
        _init_setup_model: Whether to build the network at creation
    """
    def __init__(
        self,
        policy: Union[str, Type[SACPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 3e-4,
        buffer_size: int = 1_000_000,
        learning_starts: int = 100,
        batch_size: int = 256,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = 1,
        gradient_steps: int = 1,
        action_noise: Optional[ActionNoise] = None,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        n_steps: int = 1,
        ent_coef: Union[str, float] = "auto",
        target_update_interval: int = 1,
        target_entropy: Union[str, float] = "auto",
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        use_sde_at_warmup: bool = False,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[torch.device, str] = "auto",
        _init_setup_model: bool = True,
    ): ...
```
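SAC's critics regress toward an entropy-regularized ("soft") target: y = r + γ(1 − d)(min_i Q'_i(s', a') − α·log π(a'|s')). Here a' is sampled from the current policy and α is the entropy coefficient. The target networks then track the online networks through a Polyak (soft) update controlled by `tau`. With `target_entropy="auto"`, the entropy target is minus the action dimensionality. The plain-Python sketch below assumes networks are represented by plain numbers and parameter lists. All helper names are hypothetical.

```python
from typing import List, Sequence


def soft_q_target(
    reward: float,
    done: bool,
    next_q_values: Sequence[float],
    next_log_prob: float,
    ent_coef: float,
    gamma: float = 0.99,
) -> float:
    """Entropy-regularized TD target for one transition.

    next_q_values holds each target critic's estimate Q'_i(s', a') for an
    action a' sampled from the current policy; next_log_prob is log pi(a'|s').
    """
    soft_value = min(next_q_values) - ent_coef * next_log_prob
    return reward + (0.0 if done else gamma) * soft_value


def polyak_update(params: List[float], target_params: List[float], tau: float) -> List[float]:
    """Move each target parameter a fraction tau toward the online parameter."""
    return [tau * p + (1.0 - tau) * t for p, t in zip(params, target_params)]


def auto_target_entropy(action_shape: Sequence[int]) -> float:
    """Heuristic target entropy: minus the number of action dimensions."""
    size = 1
    for dim in action_shape:
        size *= dim
    return -float(size)
```

With the default `tau=0.005`, each update closes about 0.5% of the gap between a target parameter and its online counterpart.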
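
### Twin Delayed Deep Deterministic Policy Gradient (TD3)

Off-policy algorithm that reduces DDPG's overestimation of Q-values through three changes:

- Twin critics, with the minimum of the two used as the target (clipped double-Q learning).
- Delayed policy updates.
- Target policy smoothing, which adds clipped noise to the target action.

TD3 supports continuous action spaces only and is generally more stable than DDPG.
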
```python { .api }
class TD3(OffPolicyAlgorithm):
    """
    Twin Delayed Deep Deterministic Policy Gradient algorithm.

    Args:
        policy: Policy class or string ("MlpPolicy", "CnnPolicy", "MultiInputPolicy")
        env: Environment or environment ID
        learning_rate: Learning rate; can be a function of the remaining progress (from 1 to 0)
        buffer_size: Size of the replay buffer
        learning_starts: Number of environment steps to collect before learning starts
        batch_size: Minibatch size for training
        tau: Soft update coefficient for target networks
        gamma: Discount factor
        train_freq: Update the model every train_freq steps; a tuple such as (1, "episode") changes the unit
        gradient_steps: Gradient steps per update (-1: as many as environment steps collected in the rollout)
        action_noise: Action noise for exploration
        replay_buffer_class: Replay buffer class (None: chosen automatically)
        replay_buffer_kwargs: Additional replay buffer arguments
        optimize_memory_usage: Enable a memory-efficient variant of the replay buffer
        n_steps: Number of steps for n-step return calculation
        policy_delay: Update the policy and target networks once every policy_delay critic updates
        target_policy_noise: Standard deviation of the Gaussian noise added to the target action (target policy smoothing)
        target_noise_clip: Limit on the absolute value of the target policy smoothing noise
        stats_window_size: Number of episodes averaged for rollout logging
        tensorboard_log: Path to TensorBoard log directory
        policy_kwargs: Additional arguments for policy construction
        verbose: Verbosity level (0: no output, 1: info, 2: debug)
        seed: Seed for the pseudo-random number generators
        device: PyTorch device placement ("auto", "cpu", "cuda")
        _init_setup_model: Whether to build the network at creation
    """
    def __init__(
        self,
        policy: Union[str, Type[TD3Policy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 1e-3,
        buffer_size: int = 1_000_000,
        learning_starts: int = 100,
        batch_size: int = 256,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = 1,
        gradient_steps: int = 1,
        action_noise: Optional[ActionNoise] = None,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        n_steps: int = 1,
        policy_delay: int = 2,
        target_policy_noise: float = 0.2,
        target_noise_clip: float = 0.5,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[torch.device, str] = "auto",
        _init_setup_model: bool = True,
    ): ...
```
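TD3's critic target combines target policy smoothing with clipped double-Q learning:

1. The smoothing noise ε ~ N(0, `target_policy_noise`²) is clipped to ±`target_noise_clip`.
2. The clipped noise is added to the target actor's action, and the result is clipped to the action bounds.
3. The smaller of the target critics' estimates is bootstrapped.

Separately, the actor and target networks change only once every `policy_delay` critic updates. The sketch below uses scalar states and actions and takes the noise sample as an argument so it stays deterministic. It also assumes actions normalized to [-1, 1]. The function names are hypothetical.

```python
from typing import Callable, Sequence


def td3_target(
    reward: float,
    done: bool,
    next_state: float,
    target_actor: Callable[[float], float],
    target_critics: Sequence[Callable[[float, float], float]],
    noise: float,
    target_noise_clip: float = 0.5,
    gamma: float = 0.99,
    action_low: float = -1.0,
    action_high: float = 1.0,
) -> float:
    """Clipped double-Q target with target policy smoothing for a scalar action.

    noise is a pre-drawn sample, e.g. random.gauss(0.0, target_policy_noise).
    """
    # Target policy smoothing: clip the noise, add it, then respect the action bounds
    clipped_noise = max(-target_noise_clip, min(target_noise_clip, noise))
    next_action = target_actor(next_state) + clipped_noise
    next_action = max(action_low, min(action_high, next_action))
    # Clipped double-Q: bootstrap from the most pessimistic target critic
    next_q = min(q(next_state, next_action) for q in target_critics)
    return reward + (0.0 if done else gamma) * next_q


def should_update_actor(critic_update_index: int, policy_delay: int = 2) -> bool:
    """Delayed policy updates: actor and targets change every policy_delay critic steps."""
    return critic_update_index % policy_delay == 0
```

Passing a single critic and `target_noise_clip=0.0` reduces this target to r + γ(1 − d)·Q'(s', μ'(s')), which is the DDPG target.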

### Deep Deterministic Policy Gradient (DDPG)

Off-policy algorithm for continuous control. It extends DQN-style Q-learning to continuous actions by training a deterministic actor with the gradient of a learned critic. It is implemented as a special case of TD3 with a single critic, no policy delay, and no target policy smoothing.

```python { .api }
class DDPG(TD3):
    """
    Deep Deterministic Policy Gradient algorithm.

    Args:
        Same meaning as the corresponding TD3 arguments (note the different defaults
        in the signature below). The TD3-specific options are not exposed; they are
        fixed internally to turn off TD3's additions:
        - policy_delay: 1 (the policy is updated after every critic update)
        - target_noise_clip: 0.0 (target policy smoothing noise is clipped to zero)
        - a single critic (n_critics=1) unless overridden through policy_kwargs
    """
    def __init__(
        self,
        policy: Union[str, Type[TD3Policy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 1e-4,
        buffer_size: int = 1_000_000,
        learning_starts: int = 100,
        batch_size: int = 100,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = (1, "episode"),
        gradient_steps: int = -1,
        action_noise: Optional[ActionNoise] = None,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[torch.device, str] = "auto",
        _init_setup_model: bool = True,
    ): ...
```
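The "policy gradient" half of DDPG trains a deterministic actor μ_θ(s) by gradient ascent on the critic. The update follows the chain rule ∇_θ Q(s, μ_θ(s)) = ∇_a Q(s, a)|_{a=μ_θ(s)} · ∇_θ μ_θ(s). The toy sketch below makes three simplifying assumptions:

- The critic is a hypothetical fixed function, Q(s, a) = −(a − 2s)², whose best action is a = 2s.
- The actor is a one-parameter linear model, μ(s) = w·s.
- The chain rule is applied by hand.

Under these assumptions, the actor weight converges toward 2. This illustrates only the update rule. In practice the critic is learned at the same time and gradients come from automatic differentiation.

```python
from typing import Sequence


def critic(state: float, action: float) -> float:
    """Hypothetical fixed critic: the best action in state s is a = 2 * s."""
    return -(action - 2.0 * state) ** 2


def dq_da(state: float, action: float) -> float:
    """Analytic gradient of the critic with respect to the action."""
    return -2.0 * (action - 2.0 * state)


def dpg_step(w: float, states: Sequence[float], lr: float = 0.05) -> float:
    """One deterministic policy gradient ascent step for the actor mu(s) = w * s."""
    # Chain rule: dQ/dw = dQ/da evaluated at a = mu(s), times dmu/dw = s
    grad = sum(dq_da(s, w * s) * s for s in states) / len(states)
    return w + lr * grad  # ascend Q, i.e. descend the actor loss -Q


def train_actor(w: float, states: Sequence[float], steps: int = 200) -> float:
    """Repeat the ascent step on a fixed batch of states."""
    for _ in range(steps):
        w = dpg_step(w, states)
    return w
```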
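
### Deep Q-Network (DQN)

Off-policy, value-based algorithm for discrete action spaces. It learns Q-values with a neural network and stabilizes training with experience replay and a target network. Exploration follows an ε-greedy schedule that decays linearly from `exploration_initial_eps` to `exploration_final_eps` over the first `exploration_fraction` of training.
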
```python { .api }
class DQN(OffPolicyAlgorithm):
    """
    Deep Q-Network algorithm.

    Args:
        policy: Policy class or string ("MlpPolicy", "CnnPolicy", "MultiInputPolicy")
        env: Environment or environment ID
        learning_rate: Learning rate; can be a function of the remaining progress (from 1 to 0)
        buffer_size: Size of the replay buffer
        learning_starts: Number of environment steps to collect before learning starts
        batch_size: Minibatch size for training
        tau: Soft update coefficient (1.0: hard update, i.e. copy the weights)
        gamma: Discount factor
        train_freq: Update the model every train_freq steps; a tuple such as (1, "episode") changes the unit
        gradient_steps: Gradient steps per update
        replay_buffer_class: Replay buffer class (None: chosen automatically)
        replay_buffer_kwargs: Additional replay buffer arguments
        optimize_memory_usage: Enable a memory-efficient variant of the replay buffer
        n_steps: Number of steps for n-step return calculation
        target_update_interval: Update the target network every n environment steps
        exploration_fraction: Fraction of the training period over which the exploration rate decays
        exploration_initial_eps: Initial probability of taking a random action
        exploration_final_eps: Final probability of taking a random action
        max_grad_norm: Maximum gradient norm for gradient clipping
        stats_window_size: Number of episodes averaged for rollout logging
        tensorboard_log: Path to TensorBoard log directory
        policy_kwargs: Additional arguments for policy construction
        verbose: Verbosity level (0: no output, 1: info, 2: debug)
        seed: Seed for the pseudo-random number generators
        device: PyTorch device placement ("auto", "cpu", "cuda")
        _init_setup_model: Whether to build the network at creation
    """
    def __init__(
        self,
        policy: Union[str, Type[DQNPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 1e-4,
        buffer_size: int = 1_000_000,
        learning_starts: int = 100,
        batch_size: int = 32,
        tau: float = 1.0,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = 4,
        gradient_steps: int = 1,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        n_steps: int = 1,
        target_update_interval: int = 10000,
        exploration_fraction: float = 0.1,
        exploration_initial_eps: float = 1.0,
        exploration_final_eps: float = 0.05,
        max_grad_norm: float = 10,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[torch.device, str] = "auto",
        _init_setup_model: bool = True,
    ): ...
```
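The exploration schedule and the Q-learning target fit in a few lines. The rate ε falls linearly from `exploration_initial_eps` to `exploration_final_eps` during the first `exploration_fraction` of training and then stays constant. Each action is uniformly random with probability ε and greedy otherwise. The target bootstraps from the target network's best next action. The plain-Python sketch below represents Q-values as lists. The helper names are hypothetical.

```python
import random
from typing import Sequence


def linear_epsilon(
    step: int,
    total_timesteps: int,
    exploration_fraction: float = 0.1,
    exploration_initial_eps: float = 1.0,
    exploration_final_eps: float = 0.05,
) -> float:
    """Linearly annealed exploration rate, constant after the decay period."""
    progress = step / total_timesteps  # fraction of training completed
    if progress >= exploration_fraction:
        return exploration_final_eps
    frac = progress / exploration_fraction
    return exploration_initial_eps + frac * (exploration_final_eps - exploration_initial_eps)


def epsilon_greedy(q_values: Sequence[float], epsilon: float, rng: random.Random) -> int:
    """Pick a uniformly random action with probability epsilon, else the argmax."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])


def dqn_target(reward: float, done: bool, next_q_target: Sequence[float], gamma: float = 0.99) -> float:
    """One-step TD target computed from the target network's Q-values."""
    return reward + (0.0 if done else gamma) * max(next_q_target)
```

Consider the defaults with `total_timesteps=100_000`. Halfway through the decay period (step 5,000), ε is 0.525. From step 10,000 onward it stays at 0.05.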
381
382
## Policy Types
383
384
All algorithms support three standard policy architectures that can be specified by string or class:
385
386
```python { .api }
# Multi-layer perceptron policy for vector observations
MlpPolicy = "MlpPolicy"

# Convolutional neural network policy for image observations
CnnPolicy = "CnnPolicy"

# Multi-input policy for dictionary observations
MultiInputPolicy = "MultiInputPolicy"
```
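As a rough rule, the policy string follows the observation layout:

- Dictionary observations need `MultiInputPolicy`.
- Image observations (3-D arrays such as height × width × channels) suit `CnnPolicy`.
- Flat vectors use `MlpPolicy`.

The helper below is hypothetical and describes observations by shape only. The library's own image check is stricter, because it also looks at dtype and value bounds.

```python
from typing import Dict, Tuple, Union

# Hypothetical observation description: a shape tuple, or a dict of named shapes
ObsSpec = Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]]


def suggest_policy(obs: ObsSpec) -> str:
    """Suggest a policy string from the observation layout (simplified heuristic)."""
    if isinstance(obs, dict):
        return "MultiInputPolicy"
    if len(obs) == 3:
        return "CnnPolicy"  # e.g. (84, 84, 3) images
    return "MlpPolicy"
```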

## Usage Examples

### Basic Algorithm Training

```python
import gymnasium as gym
from stable_baselines3 import PPO

# Create environment and agent
env = gym.make("CartPole-v1")
model = PPO("MlpPolicy", env, verbose=1)

# Train the agent
model.learn(total_timesteps=25000)

# Use the trained agent
obs, info = env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        obs, info = env.reset()
```

### Custom Policy Networks

```python
import gymnasium as gym
import torch
from stable_baselines3 import SAC

# SAC requires a continuous (Box) action space
env = gym.make("Pendulum-v1")

# Custom policy architecture: separate actor (pi) and critic (qf) layer sizes
policy_kwargs = dict(
    net_arch=dict(pi=[400, 300], qf=[400, 300]),
    activation_fn=torch.nn.ReLU,
)

model = SAC(
    "MlpPolicy",
    env,
    policy_kwargs=policy_kwargs,
    learning_rate=3e-4,
    buffer_size=1000000,
    batch_size=256,
    verbose=1,
)
```

### Continuous Control with Noise

```python
import gymnasium as gym
import numpy as np
from stable_baselines3 import TD3
from stable_baselines3.common.noise import NormalActionNoise

# TD3 requires a continuous (Box) action space
env = gym.make("Pendulum-v1")

# Create Gaussian action noise for exploration (one entry per action dimension)
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(
    mean=np.zeros(n_actions),
    sigma=0.1 * np.ones(n_actions),
)

model = TD3(
    "MlpPolicy",
    env,
    action_noise=action_noise,
    verbose=1,
)

model.learn(total_timesteps=100000)
```
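Conceptually, Gaussian exploration noise adds an independent sample with the given sigma to each action dimension. The noisy action is then clipped back into the valid range before it reaches the environment. The sketch below reproduces that idea in plain Python. It assumes actions are normalized to [-1, 1], and the function name is hypothetical. It is not the library's code path.

```python
import random
from typing import List, Sequence


def noisy_action(
    action: Sequence[float],
    sigma: Sequence[float],
    rng: random.Random,
    low: float = -1.0,
    high: float = 1.0,
) -> List[float]:
    """Add zero-mean Gaussian noise per dimension, then clip to [low, high]."""
    noisy = [a + rng.gauss(0.0, s) for a, s in zip(action, sigma)]
    return [min(high, max(low, x)) for x in noisy]
```

Clipping matters most near the bounds. An action of 0.95 with `sigma=0.1` would otherwise frequently leave the valid range.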

## Types

```python { .api }
from typing import Union, Optional, Type, Callable, Dict, Any, Tuple
import numpy as np
import torch
import gymnasium as gym
from stable_baselines3.common.type_aliases import GymEnv, Schedule, MaybeCallback
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.buffers import RolloutBuffer, ReplayBuffer
# Algorithm-specific policy base classes used in the signatures above
from stable_baselines3.sac.policies import SACPolicy
from stable_baselines3.td3.policies import TD3Policy
from stable_baselines3.dqn.policies import DQNPolicy
```