or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

algorithms.mdcommon-framework.mdenvironments.mdher.mdindex.mdtraining-utilities.md

algorithms.mddocs/

0

# Core Algorithms

1

2

Implementations of six deep reinforcement learning algorithms (PPO, A2C, SAC, TD3, DDPG and DQN) that share a consistent interface and expose extensive configuration options. Each algorithm is suited to particular action spaces and learning scenarios, as summarized below.

3

4

## Capabilities

5

6

### Proximal Policy Optimization (PPO)

7

8

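On-policy algorithm that optimizes a clipped surrogate objective, which prevents any single update from moving the policy too far from the previous one. Supports both continuous and discrete action spaces and is valued for its stability and ease of tuning; as an on-policy method, however, it usually needs more environment samples than the off-policy algorithms below.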

9

10

```python { .api }

11

class PPO(OnPolicyAlgorithm):
    """
    Proximal Policy Optimization (PPO) with a clipped surrogate objective.

    Args:
        policy: The policy to train, given either as a class or as one of the
            strings "MlpPolicy", "CnnPolicy" or "MultiInputPolicy".
        env: An environment instance, or the ID of a registered environment.
        learning_rate: Optimizer step size; either a constant or a schedule
            evaluated on the remaining training progress.
        n_steps: Rollout length collected from each environment before every update.
        batch_size: Number of samples per minibatch.
        n_epochs: Number of passes over the rollout data when optimizing the
            surrogate loss.
        gamma: Discount factor.
        gae_lambda: Generalized Advantage Estimation factor trading bias
            against variance.
        clip_range: Clipping parameter of the PPO surrogate objective
            (constant or schedule).
        clip_range_vf: Optional clipping parameter for the value function;
            None leaves the value function unclipped.
        normalize_advantage: If True, advantages are normalized.
        ent_coef: Weight of the entropy bonus, which encourages exploration.
        vf_coef: Weight of the value-function term in the total loss.
        max_grad_norm: Upper bound used when clipping the gradient norm.
        use_sde: If True, use State Dependent Exploration (gSDE).
        sde_sample_freq: How often (in steps) a new gSDE noise matrix is
            sampled; -1 samples only at the start of each rollout.
        rollout_buffer_class: Rollout buffer class; None selects the default one.
        rollout_buffer_kwargs: Extra keyword arguments used when creating the
            rollout buffer.
        target_kl: Optional upper limit on the KL divergence between updates.
        stats_window_size: Number of episodes averaged for rollout logging.
        tensorboard_log: Directory for TensorBoard logs; None disables logging.
        policy_kwargs: Extra keyword arguments forwarded to the policy constructor.
        verbose: Verbosity level: 0 silent, 1 info, 2 debug.
        seed: Seed for the random number generators.
        device: PyTorch device to run on ("auto", "cpu", "cuda").
        _init_setup_model: If True, the networks are built at construction time.
    """

    def __init__(
        self,
        policy: Union[str, Type[ActorCriticPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 3e-4,
        n_steps: int = 2048,
        batch_size: int = 64,
        n_epochs: int = 10,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: Union[float, Schedule] = 0.2,
        clip_range_vf: Optional[Union[float, Schedule]] = None,
        normalize_advantage: bool = True,
        ent_coef: float = 0.0,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
        rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
        target_kl: Optional[float] = None,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[torch.device, str] = "auto",
        _init_setup_model: bool = True,
    ):
        ...

72

```

73

74

### Advantage Actor-Critic (A2C)

75

76

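On-policy actor-critic algorithm that combines value-based and policy-based methods: the policy (actor) is trained with advantage estimates from a learned value function (critic). A2C is the synchronous, deterministic variant of A3C; it is simpler to implement and in practice usually performs as well as the asynchronous version. Supports both continuous and discrete action spaces.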

77

78

```python { .api }

79

class A2C(OnPolicyAlgorithm):
    """
    Advantage Actor-Critic (A2C), the synchronous variant of A3C.

    Args:
        policy: The policy to train, given either as a class or as one of the
            strings "MlpPolicy", "CnnPolicy" or "MultiInputPolicy".
        env: An environment instance, or the ID of a registered environment.
        learning_rate: Optimizer step size; either a constant or a schedule
            evaluated on the remaining training progress.
        n_steps: Rollout length collected from each environment before every update.
        gamma: Discount factor.
        gae_lambda: Generalized Advantage Estimation factor trading bias
            against variance (1.0 gives the classic advantage estimate).
        ent_coef: Weight of the entropy bonus, which encourages exploration.
        vf_coef: Weight of the value-function term in the total loss.
        max_grad_norm: Upper bound used when clipping the gradient norm.
        rms_prop_eps: Epsilon of the RMSprop optimizer.
        use_rms_prop: If True use RMSprop, otherwise Adam.
        use_sde: If True, use State Dependent Exploration (gSDE).
        sde_sample_freq: How often (in steps) a new gSDE noise matrix is
            sampled; -1 samples only at the start of each rollout.
        rollout_buffer_class: Rollout buffer class; None selects the default one.
        rollout_buffer_kwargs: Extra keyword arguments used when creating the
            rollout buffer.
        normalize_advantage: If True, advantages are normalized.
        stats_window_size: Number of episodes averaged for rollout logging.
        tensorboard_log: Directory for TensorBoard logs; None disables logging.
        policy_kwargs: Extra keyword arguments forwarded to the policy constructor.
        verbose: Verbosity level.
        seed: Seed for the random number generators.
        device: PyTorch device to run on.
        _init_setup_model: If True, the networks are built at construction time.
    """

    def __init__(
        self,
        policy: Union[str, Type[ActorCriticPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 7e-4,
        n_steps: int = 5,
        gamma: float = 0.99,
        gae_lambda: float = 1.0,
        ent_coef: float = 0.0,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        rms_prop_eps: float = 1e-5,
        use_rms_prop: bool = True,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
        rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
        normalize_advantage: bool = False,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[torch.device, str] = "auto",
        _init_setup_model: bool = True,
    ):
        ...

134

```

135

136

### Soft Actor-Critic (SAC)

137

138

Off-policy actor-critic algorithm that adds an entropy bonus to the objective to encourage exploration. Supports continuous (Box) action spaces only and is particularly effective for continuous control, combining good sample efficiency with stable training.

139

140

```python { .api }

141

class SAC(OffPolicyAlgorithm):
    """
    Soft Actor-Critic (SAC): off-policy actor-critic with entropy regularization.

    Args:
        policy: The policy to train, given either as a class or as one of the
            strings "MlpPolicy", "CnnPolicy" or "MultiInputPolicy".
        env: An environment instance, or the ID of a registered environment.
        learning_rate: Optimizer step size; either a constant or a schedule
            evaluated on the remaining training progress.
        buffer_size: Capacity of the replay buffer.
        learning_starts: Number of environment steps collected before
            training begins.
        batch_size: Number of samples per training minibatch.
        tau: Polyak (soft) update coefficient for the target networks.
        gamma: Discount factor.
        train_freq: Train every n steps, or every (n, "episode") episodes.
        gradient_steps: Gradient steps performed per training round.
        action_noise: Optional noise process added to actions for exploration.
        replay_buffer_class: Replay buffer class to use.
        replay_buffer_kwargs: Extra keyword arguments for the replay buffer.
        optimize_memory_usage: If True, enable memory-saving replay storage.
        n_steps: Horizon used for n-step return targets.
        ent_coef: Entropy regularization coefficient; "auto" learns it
            automatically.
        target_update_interval: Number of gradient steps between
            target-network updates.
        target_entropy: Target entropy used by automatic entropy tuning.
        use_sde: If True, use State Dependent Exploration (gSDE).
        sde_sample_freq: How often (in steps) a new gSDE noise matrix is sampled.
        use_sde_at_warmup: If True, explore with gSDE rather than uniform
            random actions during the warmup phase.
        stats_window_size: Number of episodes averaged for rollout logging.
        tensorboard_log: Directory for TensorBoard logs; None disables logging.
        policy_kwargs: Extra keyword arguments forwarded to the policy constructor.
        verbose: Verbosity level.
        seed: Seed for the random number generators.
        device: PyTorch device to run on.
        _init_setup_model: If True, the networks are built at construction time.
    """

    def __init__(
        self,
        policy: Union[str, Type[SACPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 3e-4,
        buffer_size: int = 1_000_000,
        learning_starts: int = 100,
        batch_size: int = 256,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = 1,
        gradient_steps: int = 1,
        action_noise: Optional[ActionNoise] = None,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        n_steps: int = 1,
        ent_coef: Union[str, float] = "auto",
        target_update_interval: int = 1,
        target_entropy: Union[str, float] = "auto",
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        use_sde_at_warmup: bool = False,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[torch.device, str] = "auto",
        _init_setup_model: bool = True,
    ):
        ...

206

```

207

208

### Twin Delayed Deep Deterministic Policy Gradient (TD3)

209

210

Off-policy algorithm that reduces the Q-value overestimation bias of DDPG with three changes: twin critics (clipped double Q-learning), delayed policy updates, and target policy smoothing. Supports continuous (Box) action spaces only and trains considerably more stably than DDPG.

211

212

```python { .api }

213

class TD3(OffPolicyAlgorithm):
    """
    Twin Delayed Deep Deterministic Policy Gradient (TD3).

    Args:
        policy: The policy to train, given either as a class or as one of the
            strings "MlpPolicy", "CnnPolicy" or "MultiInputPolicy".
        env: An environment instance, or the ID of a registered environment.
        learning_rate: Optimizer step size; either a constant or a schedule
            evaluated on the remaining training progress.
        buffer_size: Capacity of the replay buffer.
        learning_starts: Number of environment steps collected before
            training begins.
        batch_size: Number of samples per training minibatch.
        tau: Polyak (soft) update coefficient for the target networks.
        gamma: Discount factor.
        train_freq: Train every n steps, or every (n, "episode") episodes.
        gradient_steps: Gradient steps performed per training round.
        action_noise: Optional noise process added to actions for exploration.
        replay_buffer_class: Replay buffer class to use.
        replay_buffer_kwargs: Extra keyword arguments for the replay buffer.
        optimize_memory_usage: If True, enable memory-saving replay storage.
        n_steps: Horizon used for n-step return targets.
        policy_delay: The policy and target networks are updated once every
            ``policy_delay`` critic updates.
        target_policy_noise: Standard deviation of the smoothing noise added
            to target-policy actions.
        target_noise_clip: Absolute limit applied to the target smoothing noise.
        stats_window_size: Number of episodes averaged for rollout logging.
        tensorboard_log: Directory for TensorBoard logs; None disables logging.
        policy_kwargs: Extra keyword arguments forwarded to the policy constructor.
        verbose: Verbosity level.
        seed: Seed for the random number generators.
        device: PyTorch device to run on.
        _init_setup_model: If True, the networks are built at construction time.
    """

    def __init__(
        self,
        policy: Union[str, Type[TD3Policy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 1e-3,
        buffer_size: int = 1_000_000,
        learning_starts: int = 100,
        batch_size: int = 256,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = 1,
        gradient_steps: int = 1,
        action_noise: Optional[ActionNoise] = None,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        n_steps: int = 1,
        policy_delay: int = 2,
        target_policy_noise: float = 0.2,
        target_noise_clip: float = 0.5,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[torch.device, str] = "auto",
        _init_setup_model: bool = True,
    ):
        ...

272

```

273

274

### Deep Deterministic Policy Gradient (DDPG)

275

276

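Off-policy algorithm for continuous control (Box action spaces only) that extends DQN-style Q-learning to continuous actions by learning a deterministic policy with the deterministic policy gradient. Implemented as a special case of TD3 without twin critics, delayed policy updates, or target policy smoothing.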

277

278

```python { .api }

279

class DDPG(TD3):
    """
    Deep Deterministic Policy Gradient algorithm, implemented as a TD3 subclass.

    Unlike TD3, this constructor does not accept ``policy_delay``,
    ``target_policy_noise`` or ``target_noise_clip``. DDPG always uses
    policy_delay=1 (immediate policy updates), target_policy_noise=0.0 and
    target_noise_clip=0.0 (no target policy smoothing). The remaining
    parameters have the same meaning as in TD3, but several defaults differ;
    these are noted below.

    Args:
        policy: Policy class or string ("MlpPolicy", "CnnPolicy", "MultiInputPolicy")
        env: Environment or environment ID
        learning_rate: Learning rate, can be a function of remaining progress
            (default 1e-4; TD3 uses 1e-3)
        buffer_size: Size of replay buffer
        learning_starts: Steps before learning starts
        batch_size: Minibatch size for training (default 100; TD3 uses 256)
        tau: Soft update coefficient for target networks
        gamma: Discount factor
        train_freq: Update policy every n steps, or every (n, "episode")
            episodes (default ``(1, "episode")``, i.e. once per episode;
            TD3 uses 1 step)
        gradient_steps: Gradient steps per update (default -1; TD3 uses 1).
            Per the Stable-Baselines3 convention, -1 performs as many gradient
            steps as environment steps collected during the rollout.
        action_noise: Action noise for exploration
        replay_buffer_class: Replay buffer class
        replay_buffer_kwargs: Additional replay buffer arguments
        optimize_memory_usage: Enable memory optimizations
        tensorboard_log: Path to TensorBoard log directory
        policy_kwargs: Additional arguments for policy construction
        verbose: Verbosity level
        seed: Seed for random number generator
        device: PyTorch device placement
        _init_setup_model: Whether to build network at creation
    """

    def __init__(
        self,
        policy: Union[str, Type[TD3Policy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 1e-4,
        buffer_size: int = 1_000_000,
        learning_starts: int = 100,
        batch_size: int = 100,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = (1, "episode"),
        gradient_steps: int = -1,
        action_noise: Optional[ActionNoise] = None,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[torch.device, str] = "auto",
        _init_setup_model: bool = True,
    ):
        ...

312

```

313

314

### Deep Q-Network (DQN)

315

316

Off-policy value-based algorithm for discrete action spaces. Uses experience replay and target networks to stabilize learning of Q-values.

317

318

```python { .api }

319

class DQN(OffPolicyAlgorithm):
    """
    Deep Q-Network (DQN) for discrete action spaces.

    Args:
        policy: The policy to train, given either as a class or as one of the
            strings "MlpPolicy", "CnnPolicy" or "MultiInputPolicy".
        env: An environment instance, or the ID of a registered environment.
        learning_rate: Optimizer step size; either a constant or a schedule
            evaluated on the remaining training progress.
        buffer_size: Capacity of the replay buffer.
        learning_starts: Number of environment steps collected before
            training begins.
        batch_size: Number of samples per training minibatch.
        tau: Soft update coefficient for the target network; 1.0 means a
            hard copy.
        gamma: Discount factor.
        train_freq: Train every n steps, or every (n, "episode") episodes.
        gradient_steps: Gradient steps performed per training round.
        replay_buffer_class: Replay buffer class to use.
        replay_buffer_kwargs: Extra keyword arguments for the replay buffer.
        optimize_memory_usage: If True, enable memory-saving replay storage.
        n_steps: Horizon used for n-step return targets.
        target_update_interval: Number of environment steps between
            target-network updates.
        exploration_fraction: Fraction of the training run over which the
            exploration probability is annealed.
        exploration_initial_eps: Exploration probability at the start of training.
        exploration_final_eps: Exploration probability once annealing finishes.
        max_grad_norm: Upper bound used when clipping the gradient norm.
        stats_window_size: Number of episodes averaged for rollout logging.
        tensorboard_log: Directory for TensorBoard logs; None disables logging.
        policy_kwargs: Extra keyword arguments forwarded to the policy constructor.
        verbose: Verbosity level.
        seed: Seed for the random number generators.
        device: PyTorch device to run on.
        _init_setup_model: If True, the networks are built at construction time.
    """

    def __init__(
        self,
        policy: Union[str, Type[DQNPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 1e-4,
        buffer_size: int = 1_000_000,
        learning_starts: int = 100,
        batch_size: int = 32,
        tau: float = 1.0,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = 4,
        gradient_steps: int = 1,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        n_steps: int = 1,
        target_update_interval: int = 10000,
        exploration_fraction: float = 0.1,
        exploration_initial_eps: float = 1.0,
        exploration_final_eps: float = 0.05,
        max_grad_norm: float = 10,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[torch.device, str] = "auto",
        _init_setup_model: bool = True,
    ):
        ...

380

```

381

382

## Policy Types

383

384

All algorithms support three standard policy architectures that can be specified by string or class:

385

386

```python { .api }

387

# Built-in policy names accepted by every algorithm's `policy` argument.
MlpPolicy: str = "MlpPolicy"  # multi-layer perceptron, for flat vector observations
CnnPolicy: str = "CnnPolicy"  # convolutional network, for image observations
MultiInputPolicy: str = "MultiInputPolicy"  # for dictionary (Dict) observations

395

```

396

397

## Usage Examples

398

399

### Basic Algorithm Training

400

401

```python

402

import gymnasium as gym

from stable_baselines3 import PPO

# Create environment and agent
env = gym.make("CartPole-v1")
model = PPO("MlpPolicy", env, verbose=1)

# Train the agent
model.learn(total_timesteps=25_000)

# Run the trained agent for 1000 steps, resetting whenever an episode
# terminates (goal/failure) or is truncated (time limit).
obs, info = env.reset()
for _ in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        obs, info = env.reset()

# Release the environment's resources when done.
env.close()

419

```

420

421

### Custom Policy Networks

422

423

```python

424

import gymnasium as gym
import torch

from stable_baselines3 import SAC

# SAC requires a continuous (Box) action space, e.g. Pendulum.
env = gym.make("Pendulum-v1")

# Custom policy architecture: separate actor (pi) and critic (qf) networks,
# each with hidden layers of 400 and 300 units and ReLU activations.
policy_kwargs = dict(
    net_arch=dict(pi=[400, 300], qf=[400, 300]),
    activation_fn=torch.nn.ReLU,
)

model = SAC(
    "MlpPolicy",
    env,
    policy_kwargs=policy_kwargs,
    learning_rate=3e-4,
    buffer_size=1_000_000,
    batch_size=256,
    verbose=1,
)

441

```

442

443

### Continuous Control with Noise

444

445

```python

446

import gymnasium as gym
import numpy as np

from stable_baselines3 import TD3
from stable_baselines3.common.noise import NormalActionNoise

# TD3 requires a continuous (Box) action space, e.g. Pendulum.
env = gym.make("Pendulum-v1")

# Gaussian exploration noise with one component per action dimension.
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(
    mean=np.zeros(n_actions),
    sigma=0.1 * np.ones(n_actions),
)

model = TD3(
    "MlpPolicy",
    env,
    action_noise=action_noise,
    verbose=1,
)

model.learn(total_timesteps=100_000)

465

```

466

467

## Types

468

469

```python { .api }

470

from typing import Any, Callable, Dict, Optional, Tuple, Type, Union

import gymnasium as gym
import numpy as np
import torch

from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.buffers import ReplayBuffer, RolloutBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.dqn.policies import DQNPolicy
from stable_baselines3.sac.policies import SACPolicy
from stable_baselines3.td3.policies import TD3Policy

481

```