or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

algorithms.mdcommon-framework.mdenvironments.mdher.mdindex.mdtraining-utilities.md

common-framework.mddocs/

0

# Common Framework

1

2

Base classes, policies, and buffers that form the foundation of every algorithm in the Stable-Baselines3 library. Sharing them keeps behavior consistent across algorithms, promotes code reuse, and gives every algorithm implementation the same interface.

3

4

## Capabilities

5

6

### Base Algorithm Classes

7

8

Abstract base classes that define the core functionality shared by all reinforcement learning algorithms, including training loops, model management, and prediction interfaces.

9

10

```python { .api }

11

class BaseAlgorithm:
    """
    Abstract base class for all RL algorithms.

    Args:
        policy: Policy class or string alias (e.g. "MlpPolicy")
        env: Environment instance or registered environment ID
        learning_rate: Learning rate for optimization, either a constant or a
            schedule (function of the remaining training progress)
        policy_kwargs: Additional arguments for policy construction
        stats_window_size: Window size for rollout logging averaging
        tensorboard_log: Path to TensorBoard log directory (None disables it)
        verbose: Verbosity level (0: no output, 1: info, 2: debug)
        device: PyTorch device placement ("auto", "cpu", "cuda")
        support_multi_env: Whether algorithm supports multiple environments
        monitor_wrapper: Whether to wrap environment with Monitor
        seed: Random seed for reproducibility
        use_sde: Whether to use State Dependent Exploration
        sde_sample_freq: Sample a new SDE noise matrix every n steps
            (-1: only at the beginning of each rollout)
        supported_action_spaces: Tuple of action space classes the algorithm
            accepts (None: no restriction)
    """

    def __init__(
        self,
        policy: Union[str, Type[BasePolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule],
        policy_kwargs: Optional[Dict[str, Any]] = None,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        verbose: int = 0,
        device: Union[torch.device, str] = "auto",
        support_multi_env: bool = False,
        monitor_wrapper: bool = True,
        seed: Optional[int] = None,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        supported_action_spaces: Optional[Tuple[Type[gym.Space], ...]] = None,
    ): ...

    def learn(
        self,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 100,
        tb_log_name: str = "run",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> "BaseAlgorithm":
        """
        Train the agent for total_timesteps.

        Args:
            total_timesteps: Total number of environment steps to train for
            callback: Callback(s) called during training
            log_interval: How often to log training metrics. Concrete
                algorithms override this default and interpret it
                differently: on-policy algorithms (A2C, PPO) count training
                iterations, off-policy algorithms (SAC, TD3, DDPG, DQN)
                count episodes.
            tb_log_name: TensorBoard log name
            reset_num_timesteps: Whether to reset the timestep counter
                (False continues counting from a previous learn() call)
            progress_bar: Display progress bar

        Returns:
            Trained algorithm instance
        """

    def predict(
        self,
        observation: Union[np.ndarray, Dict[str, np.ndarray]],
        state: Optional[Tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """
        Get action from observation.

        Args:
            observation: Input observation
            state: Hidden state for recurrent policies (None otherwise)
            episode_start: Start of episode mask (used to reset recurrent states)
            deterministic: Use deterministic actions

        Returns:
            Tuple of (action, next_state); next_state is only meaningful for
            recurrent policies
        """

    def save(
        self,
        path: Union[str, pathlib.Path, io.BufferedIOBase],
        exclude: Optional[Iterable[str]] = None,
        include: Optional[Iterable[str]] = None,
    ) -> None:
        """
        Save model to file path.

        Args:
            path: File path or writable binary buffer to save to
            exclude: Names of attributes to exclude in addition to the default ones
            include: Names of attributes that would be excluded by default but
                should be saved anyway
        """

    @classmethod
    def load(
        cls,
        path: Union[str, pathlib.Path, io.BufferedIOBase],
        env: Optional[GymEnv] = None,
        device: Union[torch.device, str] = "auto",
        custom_objects: Optional[Dict[str, Any]] = None,
        print_system_info: bool = False,
        force_reset: bool = True,
        **kwargs,
    ) -> "BaseAlgorithm":
        """
        Load model from file path.

        This is a classmethod that returns a NEW model, e.g. ``model = PPO.load(path)``;
        calling it on an existing instance does not modify that instance.

        Args:
            path: File path or readable binary buffer to load from
            env: Environment to attach to the loaded model (None is enough
                for prediction only)
            device: PyTorch device placement
            custom_objects: Objects to use instead of the saved ones when
                deserializing (keys are attribute names)
            print_system_info: Print the system info of the saving and
                loading machines
            force_reset: Force an environment reset before training
            **kwargs: Extra attributes to override on the loaded model

        Returns:
            The loaded algorithm instance
        """

    def set_env(self, env: GymEnv, force_reset: bool = True) -> None:
        """
        Set new environment for the algorithm.

        Args:
            env: New environment; must have the same spaces as the current one
            force_reset: Force an environment reset before the next training run
        """

    def get_env(self) -> Optional[VecEnv]:
        """Return the current (vectorized) environment, or None if none is set."""

    def set_random_seed(self, seed: Optional[int] = None) -> None:
        """Set random seed for reproducibility (None leaves seeding unchanged)."""

117

118

class OnPolicyAlgorithm(BaseAlgorithm):
    """
    Base class for on-policy algorithms (A2C, PPO).

    Additional Args:
        n_steps: Number of steps to run per environment per update
            (the rollout buffer holds n_steps * n_envs transitions)
        gamma: Discount factor
        gae_lambda: GAE lambda, trading bias for variance in the advantage
            estimate (1.0 gives the classic advantage estimate)
        ent_coef: Entropy coefficient for the loss calculation
        vf_coef: Value function coefficient for the loss calculation
        max_grad_norm: Maximum gradient norm for clipping
        use_sde: Whether to use State Dependent Exploration
        sde_sample_freq: Sample a new SDE noise matrix every n steps
            (-1: only at the beginning of each rollout)
        rollout_buffer_class: Rollout buffer class to use (None: chosen from
            the observation space, DictRolloutBuffer for Dict observations)
        rollout_buffer_kwargs: Additional arguments for the rollout buffer
        **kwargs: Remaining BaseAlgorithm arguments (policy_kwargs, verbose,
            seed, device, ...)
    """

    def __init__(
        self,
        policy: Union[str, Type[ActorCriticPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule],
        n_steps: int,
        gamma: float,
        gae_lambda: float,
        ent_coef: float,
        vf_coef: float,
        max_grad_norm: float,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
        rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ): ...

    def collect_rollouts(
        self,
        env: VecEnv,
        callback: BaseCallback,
        rollout_buffer: RolloutBuffer,
        n_rollout_steps: int,
    ) -> bool:
        """
        Collect experiences with the current policy and fill the rollout buffer.

        Args:
            env: Training environment
            callback: Callback called at every environment step
            rollout_buffer: Buffer to fill with rollout data
            n_rollout_steps: Number of steps to collect per environment

        Returns:
            True if n_rollout_steps steps were collected, False if the
            callback stopped the rollout early
        """

156

157

class OffPolicyAlgorithm(BaseAlgorithm):
    """
    Base class for off-policy algorithms (SAC, TD3, DDPG, DQN).

    Additional Args:
        buffer_size: Replay buffer size
        learning_starts: Number of steps to collect transitions for before
            learning starts
        batch_size: Minibatch size for each gradient update
        tau: Soft (Polyak) update coefficient for target networks, in [0, 1]
        gamma: Discount factor
        train_freq: Update the model every train_freq steps, or pass a tuple
            of frequency and unit such as (5, "step") or (2, "episode")
        gradient_steps: Gradient steps per training phase (-1: as many as
            environment steps collected during the rollout)
        action_noise: Action noise for exploration
        replay_buffer_class: Replay buffer class (None: chosen automatically)
        replay_buffer_kwargs: Additional buffer arguments
        optimize_memory_usage: Enable a memory-efficient replay buffer. This
            value is forwarded to the buffer by the algorithm itself, so do
            not repeat it inside replay_buffer_kwargs.
        **kwargs: Remaining BaseAlgorithm arguments
    """

    def __init__(
        self,
        policy: Union[str, Type[BasePolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule],
        buffer_size: int = 1_000_000,
        learning_starts: int = 100,
        batch_size: int = 256,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = (1, "step"),
        gradient_steps: int = 1,
        action_noise: Optional[ActionNoise] = None,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        **kwargs,
    ): ...

    def _sample_action(
        self,
        learning_starts: int,
        action_noise: Optional[ActionNoise] = None,
        n_envs: int = 1,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Sample action with exploration noise.

        During the first learning_starts steps actions are sampled at random
        (warm-up); afterwards they come from the policy, plus action_noise
        if provided.

        Args:
            learning_starts: Number of warm-up steps before using the policy
            action_noise: Action noise for exploration
            n_envs: Number of environments (number of actions to sample)

        Returns:
            (action, buffer_action): the action to execute in the environment
            and the scaled action stored in the replay buffer; they differ
            when the action space bounds are not [-1, 1]
        """

199

```

200

201

### Policy Base Classes

202

203

Neural network architectures that define how observations are processed and actions are selected, supporting different observation spaces and algorithm requirements.

204

205

```python { .api }

206

class BaseModel(torch.nn.Module):
    """
    Base class for all neural network models.

    Args:
        observation_space: Observation space
        action_space: Action space
        lr_schedule: Learning rate schedule
        use_sde: Whether to use State Dependent Exploration
        log_std_init: Initial value of the log standard deviation
        full_std: Whether to use one gSDE std parameter per
            (feature, action) pair instead of one per feature
        sde_net_arch: Legacy option, removed from recent Stable-Baselines3
            releases (passing it raises an error there); leave it as None
        use_expln: Use expln() instead of exp() so the gSDE standard
            deviation stays positive without growing too fast
        squash_output: Squash output with tanh
        features_extractor_class: Feature extractor class
        features_extractor_kwargs: Keyword arguments for the feature extractor
        share_features_extractor: Share feature extractor between actor/critic
        normalize_images: Scale image observations to [0, 1] (divide by 255)
        optimizer_class: Optimizer class
        optimizer_kwargs: Additional optimizer arguments
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        sde_net_arch: Optional[List[int]] = None,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        share_features_extractor: bool = True,
        normalize_images: bool = True,
        optimizer_class: Type[torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ): ...

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Forward pass through the network; subclasses define the exact inputs and outputs."""

248

249

class BasePolicy(BaseModel):
    """
    Common parent of the policies used by every algorithm.

    Args:
        observation_space: Space the observations come from
        action_space: Space the actions belong to
        lr_schedule: Schedule producing the learning rate
        use_sde: Enable State Dependent Exploration
        **kwargs: Further options, handed on to BaseModel
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        use_sde: bool = False,
        **kwargs,
    ): ...

    def forward(self, obs: torch.Tensor, deterministic: bool = False) -> torch.Tensor:
        """Map a batch of observation tensors to action tensors."""

    def predict(
        self,
        observation: Union[np.ndarray, Dict[str, np.ndarray]],
        state: Optional[Tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """
        Compute the action, and the next hidden state, for an observation.

        Args:
            observation: Raw observation (array or dict of arrays)
            state: Previous hidden state (recurrent policies only)
            episode_start: Mask flagging environments that just started an episode
            deterministic: Return the deterministic action instead of a sample

        Returns:
            (action, next_state)
        """

    def _predict(self, observation: torch.Tensor, deterministic: bool = False) -> torch.Tensor:
        """Internal tensor-level prediction method; concrete policies implement it."""

285

286

class ActorCriticPolicy(BasePolicy):
    """
    Policy with both actor and critic networks for on-policy algorithms.

    Args:
        observation_space: Observation space
        action_space: Action space
        lr_schedule: Learning rate schedule
        net_arch: Network architecture specification. Either a list of layer
            sizes used for both the actor and the critic, e.g. [64, 64], or a
            dict with separate "pi" (actor) and "vf" (critic) lists, e.g.
            dict(pi=[64, 64], vf=[64, 64]). None uses the default architecture.
        activation_fn: Activation function class
        ortho_init: Use orthogonal initialization
        use_sde: Whether to use State Dependent Exploration
        log_std_init: Initial value of the log standard deviation
        full_std: Whether to use one gSDE std parameter per
            (feature, action) pair instead of one per feature
        sde_net_arch: Legacy option, removed from recent Stable-Baselines3
            releases (passing it raises an error there); leave it as None
        use_expln: Use expln() instead of exp() so the gSDE standard
            deviation stays positive without growing too fast
        squash_output: Squash output with tanh
        features_extractor_class: Feature extractor class
        features_extractor_kwargs: Keyword arguments for the feature extractor
        share_features_extractor: Share feature extractor between actor/critic
        normalize_images: Scale image observations to [0, 1] (divide by 255)
        optimizer_class: Optimizer class
        optimizer_kwargs: Additional optimizer arguments
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
        activation_fn: Type[torch.nn.Module] = torch.nn.Tanh,
        ortho_init: bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        sde_net_arch: Optional[List[int]] = None,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        share_features_extractor: bool = True,
        normalize_images: bool = True,
        optimizer_class: Type[torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ): ...

    def forward(
        self, obs: torch.Tensor, deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass through both the actor and the critic networks.

        Args:
            obs: Observation batch
            deterministic: Use deterministic actions instead of sampling

        Returns:
            (actions, values, log_prob): sampled actions, critic value
            estimates, and the log probability of the sampled actions
        """

    def evaluate_actions(
        self, obs: torch.Tensor, actions: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """
        Evaluate given actions under the current policy (used during training).

        Returns:
            (values, log_prob, entropy); entropy is None when the action
            distribution has no analytical entropy
        """

    def get_distribution(self, obs: torch.Tensor) -> Distribution:
        """Get action distribution from observation."""

    def predict_values(self, obs: torch.Tensor) -> torch.Tensor:
        """Predict state values using critic network."""

345

```

346

347

### Common Policy Aliases

348

349

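Every algorithm module exports the policy aliases `MlpPolicy`, `CnnPolicy`, and `MultiInputPolicy`, which can also be passed as strings (e.g. `"MlpPolicy"`). What each alias resolves to depends on the algorithm: the mapping below is the one used by the on-policy algorithms (A2C, PPO), while off-policy algorithms such as SAC bind the same names to their own policy classes.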

350

351

```python { .api }

352

# Policy aliases exported by the on-policy algorithm modules (A2C, PPO).
# Off-policy modules (SAC, TD3, DDPG, DQN) export the same three names, but
# bound to their own policy classes (e.g. stable_baselines3.sac.MlpPolicy is
# SAC's policy, not ActorCriticPolicy).
MlpPolicy = ActorCriticPolicy  # flat/vector observations
CnnPolicy = ActorCriticCnnPolicy  # image observations
MultiInputPolicy = MultiInputActorCriticPolicy  # Dict observations

356

357

# Import examples:

358

from stable_baselines3.ppo import MlpPolicy as PPOMlpPolicy

359

from stable_baselines3.a2c import CnnPolicy as A2CCnnPolicy

360

from stable_baselines3.sac import MlpPolicy as SACMlpPolicy

361

```

362

363

### Experience Buffers

364

365

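Storage mechanisms for training data that enable different sampling strategies and memory management approaches for various algorithm types. On-policy algorithms use rollout buffers, which hold only the most recent rollout and are refilled before every update; off-policy algorithms use replay buffers, which keep past transitions and sample them at random.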

366

367

```python { .api }

368

class BaseBuffer:
    """
    Abstract base class for all experience buffers.

    Args:
        buffer_size: Maximum buffer capacity
        observation_space: Observation space
        action_space: Action space
        device: PyTorch device on which sampled tensors are placed
        n_envs: Number of parallel environments
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        device: Union[torch.device, str] = "auto",
        n_envs: int = 1,
    ): ...

    def add(self, *args, **kwargs) -> None:
        """Add experience to buffer."""

    def get(self, *args, **kwargs) -> Any:
        """Retrieve stored experience; the exact form is defined by each concrete buffer (see RolloutBuffer.get)."""

    def sample(self, batch_size: int, env: Optional[VecEnv] = None) -> Any:
        """
        Sample a batch of stored elements uniformly at random.

        Args:
            batch_size: Number of elements to sample
            env: Environment whose normalization statistics (e.g. VecNormalize)
                are applied to the sampled data, if any

        Returns:
            Batch of experience samples
        """

    def reset(self) -> None:
        """Reset buffer to empty state."""

    def size(self) -> int:
        """Number of elements currently stored (at most buffer_size)."""

399

400

class RolloutBuffer(BaseBuffer):
    """
    Buffer for on-policy algorithms that stores rollout trajectories.

    Args:
        buffer_size: Buffer capacity (typically n_steps * n_envs)
        observation_space: Observation space
        action_space: Action space
        device: PyTorch device placement
        gae_lambda: GAE lambda, trading bias for variance in the advantage
            estimate (1 gives the classic advantage estimate)
        gamma: Discount factor
        n_envs: Number of parallel environments
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        device: Union[torch.device, str] = "auto",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ): ...

    def add(
        self,
        obs: np.ndarray,
        actions: np.ndarray,
        rewards: np.ndarray,
        episode_starts: np.ndarray,
        values: torch.Tensor,
        log_probs: torch.Tensor,
    ) -> None:
        """
        Add one step of rollout data (for every environment) to the buffer.

        Note: in the Stable-Baselines3 source these parameters are named in
        the singular (``action``, ``reward``, ``episode_start``, ``value``,
        ``log_prob``); passing them positionally works with either spelling.

        Args:
            obs: Observations
            actions: Actions taken
            rewards: Rewards received
            episode_starts: Episode start flags
            values: State value estimates
            log_probs: Action log probabilities
        """

    def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSamples, None, None]:
        """
        Iterate over the stored rollout in shuffled minibatches.

        Args:
            batch_size: Size of each minibatch (None: the whole buffer as a
                single batch)

        Yields:
            Batches of rollout data
        """

    def compute_returns_and_advantage(
        self, last_values: torch.Tensor, dones: np.ndarray
    ) -> None:
        """
        Compute returns and advantages for the stored rollout using GAE.

        Args:
            last_values: Value estimates for the observation after the last
                stored step (used to bootstrap)
            dones: Whether that last step ended an episode, per environment
        """

466

467

class ReplayBuffer(BaseBuffer):
    """
    Experience replay buffer for off-policy algorithms.

    Args:
        buffer_size: Maximum buffer capacity
        observation_space: Observation space
        action_space: Action space
        device: PyTorch device placement
        n_envs: Number of parallel environments
        optimize_memory_usage: Store each observation only once (the next
            observation is read from the following slot), roughly halving
            observation memory
        handle_timeout_termination: Treat episodes cut off by a time limit
            (``info["TimeLimit.truncated"]``) as non-terminal when bootstrapping

    Note:
        optimize_memory_usage=True and handle_timeout_termination=True cannot
        be combined (ValueError), so enabling the memory optimization also
        requires handle_timeout_termination=False.
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        device: Union[torch.device, str] = "auto",
        n_envs: int = 1,
        optimize_memory_usage: bool = False,
        handle_timeout_termination: bool = True,
    ): ...

    def add(
        self,
        obs: np.ndarray,
        next_obs: np.ndarray,
        actions: np.ndarray,
        rewards: np.ndarray,
        dones: np.ndarray,
        infos: List[Dict[str, Any]],
    ) -> None:
        """
        Add one transition (for every environment) to the replay buffer.

        Note: in the Stable-Baselines3 source the action/reward/done
        parameters are named in the singular (``action``, ``reward``,
        ``done``); passing them positionally works with either spelling.

        Args:
            obs: Current observations
            next_obs: Next observations
            actions: Actions taken
            rewards: Rewards received
            dones: Episode termination flags
            infos: Per-environment info dicts (read for timeout handling)
        """

    def sample(self, batch_size: int, env: Optional[VecEnv] = None) -> ReplayBufferSamples:
        """
        Sample a batch of transitions uniformly at random.

        Args:
            batch_size: Number of transitions to sample
            env: Environment whose normalization statistics (e.g. VecNormalize)
                are applied to the sampled observations and rewards, if any

        Returns:
            Batch of experience samples
        """

523

524

class DictRolloutBuffer(RolloutBuffer):
    """RolloutBuffer variant for observations that are dictionaries of arrays."""

526

527

class DictReplayBuffer(ReplayBuffer):
    """
    Replay buffer for dictionary observations.

    Unlike ReplayBuffer, it does not support optimize_memory_usage=True.
    """

529

```

530

531

## Usage Examples

532

533

### Custom Policy Architecture

534

535

```python

536

from stable_baselines3 import PPO
from stable_baselines3.common.policies import ActorCriticPolicy
import torch.nn as nn

# Any registered Gymnasium environment id (or an environment instance)
env = "CartPole-v1"

# Separate actor ("pi") and critic ("vf") networks, two 128-unit layers each.
# net_arch is a plain dict (or a plain list of layer sizes used by both
# networks); the old list-wrapped form [dict(pi=..., vf=...)] is deprecated
# since SB3 v1.8.0 and only triggers a warning.
policy_kwargs = dict(
    net_arch=dict(pi=[128, 128], vf=[128, 128]),
    activation_fn=nn.ReLU,
    ortho_init=True,
)

# "MlpPolicy" is PPO's alias for ActorCriticPolicy; passing the class works too.
model = PPO(
    "MlpPolicy",
    env,
    policy_kwargs=policy_kwargs,
    verbose=1,
)

553

```

554

555

### Custom Buffer Configuration

556

557

```python

558

from stable_baselines3 import SAC
from stable_baselines3.common.buffers import ReplayBuffer

# Any registered Gymnasium environment id with a continuous action space
env = "Pendulum-v1"

# optimize_memory_usage is an argument of the algorithm: SAC forwards it to
# the replay buffer itself, so repeating it in replay_buffer_kwargs would pass
# the keyword twice (TypeError). ReplayBuffer also rejects
# optimize_memory_usage=True together with handle_timeout_termination=True
# (ValueError), so timeout handling must be switched off here; time-limit
# truncations are then treated like real episode ends.
replay_buffer_kwargs = dict(
    handle_timeout_termination=False,
)

model = SAC(
    "MlpPolicy",
    env,
    buffer_size=500_000,
    optimize_memory_usage=True,
    replay_buffer_class=ReplayBuffer,
    replay_buffer_kwargs=replay_buffer_kwargs,
    verbose=1,
)

574

```

575

576

### Accessing Buffer Data

577

578

```python

579

# After training an off-policy model (SAC, TD3, DDPG, DQN), its replay
# buffer is available as an attribute.
buffer = model.replay_buffer

# Draw 256 random transitions (pass env=... to apply normalization statistics)
batch = buffer.sample(256)
observations, actions, rewards = batch.observations, batch.actions, batch.rewards

587

```

588

589

## Types

590

591

```python { .api }

592

import io
import pathlib
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple, Type, Union

import gymnasium as gym
import numpy as np
import torch

from stable_baselines3.common.buffers import ReplayBuffer, RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.distributions import Distribution
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.policies import BasePolicy, ActorCriticPolicy, ActorCriticCnnPolicy, MultiInputActorCriticPolicy
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor
from stable_baselines3.common.type_aliases import GymEnv, Schedule, MaybeCallback, PyTorchObs, TensorDict, ReplayBufferSamples, RolloutBufferSamples
from stable_baselines3.common.vec_env import VecEnv

603

```