tessl/pypi-stable-baselines3

PyTorch version of Stable Baselines, implementations of reinforcement learning algorithms.

Workspace: tessl
Visibility: Public
Describes: pypipkg:pypi/stable-baselines3@2.7.x

To install, run

npx @tessl/cli install tessl/pypi-stable-baselines3@2.7.0

# Stable Baselines3

Stable Baselines3 is a Python library of reliable reinforcement learning algorithm implementations built on PyTorch. It exposes a unified, scikit-learn-like interface across algorithms, offers extensive customization options, and targets both research and practical deployment of deep reinforcement learning.

## Package Information

- **Package Name**: stable-baselines3
- **Language**: Python
- **Installation**: `pip install stable-baselines3`

## Core Imports

```python
from stable_baselines3 import PPO, A2C, SAC, TD3, DDPG, DQN
```

Common utilities and components:

```python
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.utils import set_random_seed, get_system_info
```

## Basic Usage

```python
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import EvalCallback

# Create separate vectorized environments for training and evaluation.
# Each lambda builds its own environment, so nothing is captured by name.
env = DummyVecEnv([lambda: gym.make("CartPole-v1")])
eval_env = DummyVecEnv([lambda: gym.make("CartPole-v1")])

# Create PPO agent
model = PPO("MlpPolicy", env, verbose=1)

# Set up evaluation callback
eval_callback = EvalCallback(
    eval_env,
    best_model_save_path="./logs/",
    log_path="./logs/",
    eval_freq=10000,
    deterministic=True,
    render=False,
)

# Train the agent
model.learn(total_timesteps=100000, callback=eval_callback)

# Save the trained model
model.save("ppo_cartpole")

# Load and use the trained model
model = PPO.load("ppo_cartpole")

# Test the trained agent
obs = env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    # VecEnv.step returns one entry per environment and resets finished
    # environments automatically, so no manual reset is needed here
    obs, rewards, dones, infos = env.step(action)
```

## Architecture

Stable Baselines3 follows a hierarchical architecture that promotes code reuse and extensibility:

- **Algorithm Classes**: Implementation of specific RL algorithms (PPO, SAC, etc.)
- **Base Classes**: Abstract base classes providing common functionality
  - `BaseAlgorithm`: Core training loop and model management
  - `OnPolicyAlgorithm`: Base for algorithms like A2C and PPO
  - `OffPolicyAlgorithm`: Base for algorithms like SAC, TD3, DDPG, DQN
- **Policies**: Neural network architectures for different observation spaces
- **Buffers**: Experience storage for training (rollout buffers, replay buffers)
- **Common Components**: Utilities, callbacks, environment wrappers, and evaluation tools

This design enables consistent interfaces across algorithms while allowing for algorithm-specific optimizations.
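To make the layering concrete, here is a minimal, library-free sketch of how a shared base class can own the outer training loop while subclasses only decide how experience is collected. All class names (`ToyBaseAlgorithm`, `ToyOnPolicy`, `ToyOffPolicy`) are hypothetical and are not part of Stable Baselines3; the real base classes do considerably more.

```python
from abc import ABC, abstractmethod


class ToyBaseAlgorithm(ABC):
    """Hypothetical stand-in for a shared base class that owns the training loop."""

    def __init__(self) -> None:
        self.num_timesteps = 0
        self.updates = 0

    @abstractmethod
    def collect(self) -> int:
        """Gather experience and return the number of environment steps taken."""

    @abstractmethod
    def train(self) -> None:
        """Run one optimisation phase on the gathered experience."""

    def learn(self, total_timesteps: int) -> "ToyBaseAlgorithm":
        # The loop is written once here and reused by every subclass.
        while self.num_timesteps < total_timesteps:
            self.num_timesteps += self.collect()
            self.train()
        return self


class ToyOnPolicy(ToyBaseAlgorithm):
    """Collects a fixed-length rollout, trains on it, then moves on."""

    def __init__(self, n_steps: int = 8) -> None:
        super().__init__()
        self.n_steps = n_steps

    def collect(self) -> int:
        return self.n_steps

    def train(self) -> None:
        self.updates += 1


class ToyOffPolicy(ToyBaseAlgorithm):
    """Takes one step per iteration; a real version would train from a replay buffer."""

    def collect(self) -> int:
        return 1

    def train(self) -> None:
        self.updates += 1


on_policy = ToyOnPolicy(n_steps=8).learn(total_timesteps=32)
off_policy = ToyOffPolicy().learn(total_timesteps=32)
print(on_policy.updates, off_policy.updates)
```

Both toy algorithms expose the same `learn` call, which is the property the hierarchy is meant to provide; only the collection granularity differs.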

## Capabilities

### Core Algorithms

Implementations of six deep reinforcement learning algorithms (PPO, A2C, SAC, TD3, DDPG, DQN) with consistent interfaces and extensive configuration options.

```python { .api }
class PPO(OnPolicyAlgorithm):
    """Proximal Policy Optimization algorithm."""
    def __init__(self, policy, env, learning_rate=3e-4, n_steps=2048, batch_size=64,
                 n_epochs=10, gamma=0.99, gae_lambda=0.95, clip_range=0.2, **kwargs): ...

class A2C(OnPolicyAlgorithm):
    """Advantage Actor-Critic algorithm."""
    def __init__(self, policy, env, learning_rate=7e-4, n_steps=5, gamma=0.99,
                 gae_lambda=1.0, ent_coef=0.0, vf_coef=0.5, **kwargs): ...

class SAC(OffPolicyAlgorithm):
    """Soft Actor-Critic algorithm."""
    def __init__(self, policy, env, learning_rate=3e-4, buffer_size=1000000,
                 batch_size=256, tau=0.005, gamma=0.99, **kwargs): ...

class TD3(OffPolicyAlgorithm):
    """Twin Delayed Deep Deterministic Policy Gradient algorithm."""
    def __init__(self, policy, env, learning_rate=1e-3, buffer_size=1000000,
                 batch_size=256, tau=0.005, gamma=0.99, **kwargs): ...

class DDPG(TD3):
    """Deep Deterministic Policy Gradient algorithm."""
    def __init__(self, policy, env, **kwargs): ...

class DQN(OffPolicyAlgorithm):
    """Deep Q-Network algorithm."""
    def __init__(self, policy, env, learning_rate=1e-4, buffer_size=1000000,
                 batch_size=32, tau=1.0, gamma=0.99, **kwargs): ...
```

[Core Algorithms](./algorithms.md)
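The `clip_range` parameter in the PPO signature bounds how far a single update can move the policy. As a rough illustration, the per-sample clipped surrogate objective from the PPO paper can be written in plain Python as below. The function name `clipped_surrogate` is hypothetical, and this scalar sketch is not how the library computes its loss (which operates on batched tensors and adds value and entropy terms).

```python
def clipped_surrogate(ratio: float, advantage: float, clip_range: float = 0.2) -> float:
    """Per-sample PPO objective: min(r * A, clip(r, 1 - eps, 1 + eps) * A).

    ratio is pi_new(a|s) / pi_old(a|s); advantage is the advantage estimate.
    """
    clipped_ratio = max(1.0 - clip_range, min(ratio, 1.0 + clip_range))
    # Taking the minimum keeps the objective pessimistic: gains from moving
    # the ratio outside the clip interval are ignored, losses are not.
    return min(ratio * advantage, clipped_ratio * advantage)


# A large ratio with a positive advantage is capped at (1 + clip_range) * A.
print(clipped_surrogate(ratio=1.5, advantage=1.0))
# A large ratio with a negative advantage is not capped, so the penalty stays.
print(clipped_surrogate(ratio=1.5, advantage=-1.0))
```

With the default `clip_range=0.2`, a sample only contributes gradient signal while the probability ratio stays within roughly 0.8 to 1.2 in the direction that improves the objective.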

### Common Framework

Base classes, policies, and buffers that provide the foundation for all algorithms and enable consistent behavior across the library.

```python { .api }
class BaseAlgorithm:
    """Abstract base class for all RL algorithms."""
    # Subclasses override the log_interval default
    def learn(self, total_timesteps, callback=None, log_interval=100,
              tb_log_name="run", reset_num_timesteps=True, progress_bar=False): ...
    def predict(self, observation, state=None, episode_start=None, deterministic=False): ...
    def save(self, path): ...
    @classmethod
    def load(cls, path, env=None, device="auto", **kwargs): ...

class BasePolicy:
    """Base policy class for all neural network policies."""
    def forward(self, obs, deterministic=False): ...
    def predict(self, observation, state=None, episode_start=None, deterministic=False): ...

class RolloutBuffer:
    """Buffer for on-policy algorithms."""
    def add(self, obs, action, reward, episode_start, value, log_prob): ...
    def get(self, batch_size=None): ...

class ReplayBuffer:
    """Experience replay buffer for off-policy algorithms."""
    def add(self, obs, next_obs, action, reward, done, infos): ...
    def sample(self, batch_size, env=None): ...
```

[Common Framework](./common-framework.md)
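The `add`/`sample` pair on a replay buffer is easiest to picture as a fixed-size ring: new transitions overwrite the oldest ones once the buffer is full, and training batches are drawn uniformly at random. The sketch below uses only the standard library; `TinyReplayBuffer` and its transition layout are assumptions made for illustration and do not mirror the library's array-based storage.

```python
import random
from typing import List, Tuple

# Hypothetical transition layout: (obs, action, reward, next_obs, done)
Transition = Tuple[float, int, float, float, bool]


class TinyReplayBuffer:
    """Fixed-size ring buffer with uniform sampling."""

    def __init__(self, buffer_size: int) -> None:
        self.buffer_size = buffer_size
        self.storage: List[Transition] = []
        self.pos = 0  # index of the next slot to write

    def add(self, transition: Transition) -> None:
        if len(self.storage) < self.buffer_size:
            self.storage.append(transition)
        else:
            # Buffer is full: overwrite the oldest entry
            self.storage[self.pos] = transition
        self.pos = (self.pos + 1) % self.buffer_size

    def sample(self, batch_size: int, rng: random.Random) -> List[Transition]:
        # Uniform sampling with replacement
        return [rng.choice(self.storage) for _ in range(batch_size)]

    def __len__(self) -> int:
        return len(self.storage)


buffer = TinyReplayBuffer(buffer_size=3)
for step in range(5):
    buffer.add((float(step), 0, 1.0, float(step + 1), False))

# Only the three most recent transitions remain
print(sorted(t[0] for t in buffer.storage))
batch = buffer.sample(batch_size=4, rng=random.Random(0))
```

A rollout buffer differs mainly in lifetime: it is filled once per rollout, consumed by the update, and then cleared, whereas a replay buffer keeps old experience around for reuse.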

### Vectorized Environments

Environment vectorization and wrappers for parallel training, normalization, monitoring, and other common preprocessing tasks.

```python { .api }
class DummyVecEnv:
    """Sequential vectorized environment."""
    def __init__(self, env_fns): ...
    def reset(self): ...
    def step(self, actions): ...

class SubprocVecEnv:
    """Multiprocessing vectorized environment."""
    def __init__(self, env_fns, start_method=None): ...

class VecNormalize:
    """Normalize observations and rewards."""
    def __init__(self, venv, training=True, norm_obs=True, norm_reward=True,
                 clip_obs=10.0, clip_reward=10.0, gamma=0.99, epsilon=1e-8): ...

class VecFrameStack:
    """Stack consecutive observations (frames) to give the policy temporal context."""
    def __init__(self, venv, n_stack, channels_order=None): ...
```

[Environments](./environments.md)
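The `clip_obs` and `epsilon` parameters of `VecNormalize` make more sense next to the arithmetic they control. The following standard-library sketch tracks a running mean and variance of a scalar stream and uses them to standardize and clip new values. The names `RunningMeanStd` and `normalize` are chosen for illustration, and the update rule shown (Welford's method, one sample at a time) is an assumption that may differ from the batched update the library uses.

```python
import math


class RunningMeanStd:
    """Running mean and population variance of a scalar stream (Welford's method)."""

    def __init__(self) -> None:
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations from the mean

    def update(self, x: float) -> None:
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (x - self.mean)

    @property
    def var(self) -> float:
        return self.m2 / self.count if self.count > 0 else 1.0


def normalize(x: float, stats: RunningMeanStd, clip: float = 10.0,
              epsilon: float = 1e-8) -> float:
    """Standardize x with the running statistics, then clip to [-clip, clip]."""
    scaled = (x - stats.mean) / math.sqrt(stats.var + epsilon)
    return max(-clip, min(clip, scaled))


stats = RunningMeanStd()
for value in [1.0, 2.0, 3.0, 4.0]:
    stats.update(value)

print(stats.mean, stats.var)
print(normalize(1000.0, stats))  # an outlier is clipped to the upper bound
```

Because the statistics keep changing during training, they need to be saved alongside the model and frozen at evaluation time, which is what the `training` flag in the signature above is for.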

### Training Utilities

Callbacks, noise generators, evaluation tools, and other utilities to enhance and monitor training processes.

```python { .api }
class EvalCallback:
    """Evaluate model during training."""
    def __init__(self, eval_env, callback_on_new_best=None, n_eval_episodes=5,
                 eval_freq=10000, log_path=None, best_model_save_path=None,
                 deterministic=True, render=False, verbose=1): ...

class CheckpointCallback:
    """Save model at regular intervals."""
    def __init__(self, save_freq, save_path, name_prefix="rl_model",
                 save_replay_buffer=False, save_vecnormalize=False, verbose=0): ...

class NormalActionNoise:
    """Gaussian action noise for exploration."""
    def __init__(self, mean, sigma): ...

def evaluate_policy(model, env, n_eval_episodes=10, deterministic=True,
                    render=False, callback=None, reward_threshold=None,
                    return_episode_rewards=False, warn=True): ...
```

[Training Utilities](./training-utilities.md)
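One detail that is easy to miss with `eval_freq` and `save_freq`: callbacks count calls to the vectorized `step`, and each call advances every sub-environment by one timestep. With several environments, a frequency expressed in total timesteps therefore has to be divided by the number of environments. The helper below is a small sketch of that conversion; the name `per_call_frequency` is hypothetical.

```python
def per_call_frequency(freq_in_timesteps: int, n_envs: int) -> int:
    """Convert a frequency in total timesteps into a frequency in callback calls.

    Each vectorized step advances n_envs timesteps, so the callback fires
    n_envs times less often than the raw timestep count suggests.
    """
    return max(freq_in_timesteps // n_envs, 1)


# Checkpoint roughly every 100,000 timesteps while training on 4 environments
save_freq = per_call_frequency(100_000, n_envs=4)
print(save_freq)
```

The resulting value is what would be passed as `save_freq` to `CheckpointCallback` or `eval_freq` to `EvalCallback`; the `max(..., 1)` guard keeps the frequency valid when there are more environments than requested timesteps.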

### Hindsight Experience Replay (HER)

Implementation of Hindsight Experience Replay for goal-conditioned reinforcement learning, enabling learning from failed attempts by treating them as successful attempts toward different goals.

```python { .api }
class HerReplayBuffer:
    """Replay buffer with Hindsight Experience Replay."""
    def __init__(self, buffer_size, observation_space, action_space, env, device="auto",
                 n_envs=1, optimize_memory_usage=False, handle_timeout_termination=True,
                 n_sampled_goal=4, goal_selection_strategy="future", copy_info_dict=False): ...

class GoalSelectionStrategy(Enum):
    """Enumeration of goal selection strategies (also selectable by lowercase name)."""
    FUTURE = 0   # a goal achieved later in the same episode
    FINAL = 1    # the goal achieved at the end of the episode
    EPISODE = 2  # any goal achieved during the episode
```

[HER](./her.md)
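The idea of treating failures as successes toward different goals can be shown on a toy one-dimensional episode. The sketch below relabels each transition with goals that were actually achieved at the same step or later in the episode (the "future" strategy) and recomputes a sparse reward. Everything here (`Transition`, `sparse_reward`, `relabel_future`) is a hypothetical illustration using the standard library, not the library's buffer logic.

```python
import random
from typing import List, Tuple

# Hypothetical toy transition: (achieved_goal, desired_goal, reward)
Transition = Tuple[int, int, float]


def sparse_reward(achieved_goal: int, desired_goal: int) -> float:
    """0 when the goal is reached, -1 otherwise."""
    return 0.0 if achieved_goal == desired_goal else -1.0


def relabel_future(episode: List[Transition], n_sampled_goal: int,
                   rng: random.Random) -> List[Transition]:
    """Create n_sampled_goal relabeled copies of every transition."""
    relabeled = []
    last = len(episode) - 1
    for t, (achieved_goal, _goal, _reward) in enumerate(episode):
        for _ in range(n_sampled_goal):
            # "future": reuse a goal achieved at step t or later in the episode
            new_goal = episode[rng.randint(t, last)][0]
            relabeled.append(
                (achieved_goal, new_goal, sparse_reward(achieved_goal, new_goal))
            )
    return relabeled


# The agent reached positions 1, 2, 3 but wanted 10: every original reward is -1.
episode = [(1, 10, -1.0), (2, 10, -1.0), (3, 10, -1.0)]
virtual = relabel_future(episode, n_sampled_goal=2, rng=random.Random(0))
print(len(virtual), "virtual transitions,",
      sum(r == 0.0 for _, _, r in virtual), "with a success reward")
```

The original episode contains no successful transition, while the relabeled set always contains some, which is what gives an off-policy learner a usable signal under sparse rewards. `n_sampled_goal` plays the same role here as the parameter of that name in the buffer signature: how many virtual transitions are generated per real one.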

### System Information

Utilities for retrieving system and environment information for debugging and reproducibility.

```python { .api }
def get_system_info(print_info: bool = True) -> tuple[dict[str, str], str]:
    """
    Retrieve system and Python environment info for the current system.

    Parameters:
    - print_info: Whether to print the collected information

    Returns:
    Tuple containing a dictionary with version info and a formatted string
    """
```

## Types

```python { .api }
from typing import Union, Optional, Callable, Dict, Any, List, Tuple
import numpy as np
import torch
import gymnasium as gym
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv

# Environment types
GymEnv = Union[gym.Env, gym.Wrapper]
VecEnv = Union[DummyVecEnv, SubprocVecEnv]

# Policy types
Schedule = Callable[[float], float]
MaybeCallback = Union[None, Callable, List[Callable], "BaseCallback"]

# Algorithm-specific policy types
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
from stable_baselines3.sac.policies import SACPolicy
from stable_baselines3.td3.policies import TD3Policy
from stable_baselines3.dqn.policies import DQNPolicy

# Buffer types
from stable_baselines3.common.buffers import RolloutBuffer, ReplayBuffer
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer

# Noise types
from stable_baselines3.common.noise import ActionNoise

# Observation and action types
TensorDict = Dict[str, torch.Tensor]
PyTorchObs = Union[torch.Tensor, TensorDict]

# Training frequency specification
TrainFreq = Union[int, Tuple[int, str]]
```
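The `Schedule` alias describes a function of the remaining training progress, which runs from 1.0 at the start of training down to 0.0 at the end. A linear decay is the usual first example. The sketch below defines the alias locally so it runs without the library installed; `linear_schedule` is an illustrative helper name, not a function exported by the package.

```python
from typing import Callable

Schedule = Callable[[float], float]


def linear_schedule(initial_value: float) -> Schedule:
    """Return a schedule that decays linearly from initial_value to 0."""

    def schedule(progress_remaining: float) -> float:
        # progress_remaining goes from 1.0 (start of training) to 0.0 (end)
        return progress_remaining * initial_value

    return schedule


learning_rate = linear_schedule(3e-4)
print(learning_rate(1.0), learning_rate(0.5), learning_rate(0.0))
```

A callable like this can be passed wherever a `Schedule` is accepted, for example as `learning_rate=linear_schedule(3e-4)` when constructing an algorithm, in place of a constant float.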