or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

algorithms.md common-framework.md environments.md her.md index.md training-utilities.md

docs/training-utilities.md

0

# Training Utilities

1

2

Callbacks, noise generators, evaluation tools, and other utilities to enhance and monitor training processes. These components provide essential functionality for experiment management, hyperparameter tuning, and production deployment of RL systems.

3

4

## Capabilities

5

6

### Callback System

7

8

Event-driven system for monitoring, evaluating, and controlling training processes with customizable hooks at various training stages.

9

10

```python { .api }

11

class BaseCallback:

12

"""

13

Abstract base class for training callbacks.

14

15

Args:

16

verbose: Verbosity level (0: quiet, 1: info, 2: debug)

17

"""

18

def __init__(self, verbose: int = 0): ...

19

20

def init_callback(self, model: "BaseAlgorithm") -> None:

21

"""Initialize callback with algorithm instance."""

22

23

def on_training_start(

24

self, locals_: Dict[str, Any], globals_: Dict[str, Any]

25

) -> None:

26

"""Called when training begins."""

27

28

def on_rollout_start(self) -> None:

29

"""Called before collecting rollouts."""

30

31

def on_step(self) -> bool:

32

"""

33

Called after each environment step.

34

35

Returns:

36

True to continue training, False to stop

37

"""

38

39

def on_rollout_end(self) -> None:

40

"""Called after rollout collection."""

41

42

def on_training_end(self) -> None:

43

"""Called when training ends."""

44

45

def update_locals(self, locals_: Dict[str, Any]) -> None:

46

"""Update callback with current local variables."""

47

48

class EventCallback(BaseCallback):

49

"""

50

Base class for event-triggered callbacks.

51

52

Args:

53

callback: Child callback to trigger

54

verbose: Verbosity level

55

"""

56

def __init__(self, callback: Optional["BaseCallback"] = None, verbose: int = 0): ...

57

58

def _trigger_event(self) -> bool:

59

"""Trigger child callback if conditions are met."""

60

61

def _on_event(self) -> bool:

62

"""Event handler (to be implemented by subclasses)."""

63

64

class CallbackList(BaseCallback):

65

"""

66

Container for multiple callbacks.

67

68

Args:

69

callbacks: List of callback instances

70

"""

71

def __init__(self, callbacks: List[BaseCallback]): ...

72

73

def on_training_start(

74

self, locals_: Dict[str, Any], globals_: Dict[str, Any]

75

) -> None:

76

"""Call on_training_start for all callbacks."""

77

78

def on_step(self) -> bool:

79

"""Call on_step for all callbacks, stop if any returns False."""

80

81

class EvalCallback(EventCallback):

82

"""

83

Evaluate agent during training and save best model.

84

85

Args:

86

eval_env: Environment for evaluation

87

callback_on_new_best: Callback triggered when new best model found

88

callback_after_eval: Callback triggered after evaluation

89

n_eval_episodes: Number of episodes for evaluation

90

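eval_freq: Evaluation frequency, counted in callback calls (one call per env.step(); with a vectorized env each call advances n_envs timesteps)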

91

log_path: Path for evaluation logs

92

best_model_save_path: Path to save best model

93

deterministic: Use deterministic actions during evaluation

94

render: Render evaluation episodes

95

verbose: Verbosity level

96

warn: Show warnings for evaluation issues

97

"""

98

def __init__(

99

self,

100

eval_env: Union[gym.Env, VecEnv],

101

callback_on_new_best: Optional[BaseCallback] = None,

102

callback_after_eval: Optional[BaseCallback] = None,

103

n_eval_episodes: int = 5,

104

eval_freq: int = 10000,

105

log_path: Optional[str] = None,

106

best_model_save_path: Optional[str] = None,

107

deterministic: bool = True,

108

render: bool = False,

109

verbose: int = 1,

110

warn: bool = True,

111

): ...

112

113

def _on_step(self) -> bool:

114

"""Evaluate model if eval_freq steps have passed."""

115

116

def _on_event(self) -> bool:

117

"""Perform evaluation and save best model."""

118

119

class CheckpointCallback(BaseCallback):

120

"""

121

Save model at regular intervals.

122

123

Args:

124

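save_freq: Frequency for saving checkpoints, counted in callback calls (one call per env.step(); with a vectorized env each call advances n_envs timesteps, so use max(save_freq // n_envs, 1) to express it in timesteps)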

125

save_path: Directory to save checkpoints

126

name_prefix: Prefix for checkpoint filenames

127

save_replay_buffer: Whether to save replay buffer

128

save_vecnormalize: Whether to save VecNormalize statistics

129

verbose: Verbosity level

130

"""

131

def __init__(

132

self,

133

save_freq: int,

134

save_path: str,

135

name_prefix: str = "rl_model",

136

save_replay_buffer: bool = False,

137

save_vecnormalize: bool = False,

138

verbose: int = 0,

139

): ...

140

141

def _on_step(self) -> bool:

142

"""Save checkpoint if save_freq steps have passed."""

143

144

class StopTrainingOnRewardThreshold(BaseCallback):

145

"""

146

Stop training when the mean evaluation reward reaches the threshold. Must be used with EvalCallback (passed as its callback_on_new_best).

147

148

Args:

149

reward_threshold: Minimum average reward to stop training

150

verbose: Verbosity level

151

"""

152

def __init__(self, reward_threshold: float, verbose: int = 0): ...

153

154

def _on_step(self) -> bool:

155

"""Check if reward threshold is reached."""

156

157

class StopTrainingOnMaxEpisodes(BaseCallback):

158

"""

159

Stop training after maximum number of episodes.

160

161

Args:

162

max_episodes: Maximum number of episodes

163

verbose: Verbosity level

164

"""

165

def __init__(self, max_episodes: int, verbose: int = 0): ...

166

167

def _on_step(self) -> bool:

168

"""Check if maximum episodes reached."""

169

170

class ProgressBarCallback(BaseCallback):

171

"""

172

Display training progress bar using tqdm.

173

174

Args:

175

refresh_freq: Progress bar refresh frequency

176

"""

177

def __init__(self, refresh_freq: int = 1): ...

178

179

def on_training_start(

180

self, locals_: Dict[str, Any], globals_: Dict[str, Any]

181

) -> None:

182

"""Initialize progress bar."""

183

184

def _on_step(self) -> bool:

185

"""Update progress bar."""

186

187

class EveryNTimesteps(EventCallback):

188

"""

189

Trigger a callback every n timesteps.

190

191

Args:

192

n_steps: Number of timesteps between triggers

193

callback: Callback to trigger

194

"""

195

def __init__(self, n_steps: int, callback: BaseCallback): ...

196

197

class ConvertCallback(BaseCallback):

198

"""

199

Convert functional callback (old-style) to object.

200

201

Args:

202

callback: Optional callback function

203

verbose: Verbosity level

204

"""

205

def __init__(self, callback: Optional[Callable], verbose: int = 0): ...

206

207

class StopTrainingOnNoModelImprovement(BaseCallback):

208

"""

209

Stop training if no new best model after N consecutive evaluations.

210

Must be used with EvalCallback.

211

212

Args:

213

max_no_improvement_evals: Max consecutive evaluations without improvement

214

min_evals: Number of evaluations to run before starting to count evaluations without improvement

215

verbose: Verbosity level

216

"""

217

def __init__(self, max_no_improvement_evals: int, min_evals: int = 0, verbose: int = 0): ...

218

```

219

220

### Noise Classes

221

222

Action noise generators for exploration in continuous control environments, providing various stochastic processes for effective exploration strategies.

223

224

```python { .api }

225

class ActionNoise:

226

"""Abstract base class for action noise."""

227

228

def __call__(self) -> np.ndarray:

229

"""Generate noise sample."""

230

231

def reset(self) -> None:

232

"""Reset noise state."""

233

234

class NormalActionNoise(ActionNoise):

235

"""

236

Gaussian action noise for exploration.

237

238

Args:

239

mean: Mean of the noise distribution

240

sigma: Standard deviation of the noise distribution

241

"""

242

def __init__(self, mean: np.ndarray, sigma: np.ndarray): ...

243

244

def __call__(self) -> np.ndarray:

245

"""Sample from Gaussian distribution."""

246

247

def reset(self) -> None:

248

"""Reset noise (no-op for memoryless noise)."""

249

250

class OrnsteinUhlenbeckActionNoise(ActionNoise):

251

"""

252

Ornstein-Uhlenbeck process noise for temporally correlated exploration.

253

254

Args:

255

mean: Long-run mean of the process

256

sigma: Volatility parameter

257

theta: Rate of mean reversion

258

dt: Time step

259

initial_noise: Initial noise value

260

"""

261

def __init__(

262

self,

263

mean: np.ndarray,

264

sigma: np.ndarray,

265

theta: float = 0.15,

266

dt: float = 1e-2,

267

initial_noise: Optional[np.ndarray] = None,

268

): ...

269

270

def __call__(self) -> np.ndarray:

271

"""Sample next noise value from OU process."""

272

273

def reset(self) -> None:

274

"""Reset process to initial state."""

275

276

class VectorizedActionNoise(ActionNoise):

277

"""

278

Vectorized noise for multiple environments.

279

280

Args:

281

noise_fn: Function to create noise instances

282

n_envs: Number of environments

283

"""

284

def __init__(self, noise_fn: Callable[[], ActionNoise], n_envs: int): ...

285

286

def __call__(self) -> np.ndarray:

287

"""Sample noise for all environments."""

288

289

def reset(self) -> None:

290

"""Reset all noise instances."""

291

```

292

293

### Evaluation Functions

294

295

Comprehensive evaluation utilities for assessing trained agents, including statistical analysis and performance monitoring across multiple episodes.

296

297

```python { .api }

298

def evaluate_policy(

299

model: "BaseAlgorithm",

300

env: Union[gym.Env, VecEnv],

301

n_eval_episodes: int = 10,

302

deterministic: bool = True,

303

render: bool = False,

304

callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None,

305

reward_threshold: Optional[float] = None,

306

return_episode_rewards: bool = False,

307

warn: bool = True,

308

verbose: int = 1,

309

) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:

310

"""

311

Evaluate trained agent on environment.

312

313

Args:

314

model: Trained RL algorithm

315

env: Environment for evaluation

316

n_eval_episodes: Number of episodes to evaluate

317

deterministic: Use deterministic actions

318

render: Render environment during evaluation

319

callback: Custom callback function

320

reward_threshold: Minimum expected mean reward per episode; an error is raised if the evaluated mean reward is below it

321

return_episode_rewards: Return individual episode rewards

322

warn: Show warnings

323

verbose: Verbosity level

324

325

Returns:

326

Mean reward and standard deviation, or episode rewards and lengths

327

"""

328

```

329

330

### Utility Functions

331

332

General-purpose utilities for reproducibility, tensor operations, mathematical computations, and other common tasks in RL training.

333

334

```python { .api }

335

def set_random_seed(seed: int, using_cuda: bool = False) -> None:

336

"""

337

Set random seed for reproducibility.

338

339

Args:

340

seed: Random seed value

341

using_cuda: Whether CUDA is being used

342

"""

343

344

def get_system_info() -> Dict[str, Any]:

345

"""

346

Get system information for debugging.

347

348

Returns:

349

Dictionary containing system information

350

"""

351

352

def get_device(device: Union[torch.device, str] = "auto") -> torch.device:

353

"""

354

Get PyTorch device from string specification.

355

356

Args:

357

device: Device specification ("auto", "cpu", "cuda", etc.)

358

359

Returns:

360

PyTorch device object

361

"""

362

363

def obs_as_tensor(

364

obs: Union[np.ndarray, Dict[str, np.ndarray]], device: torch.device

365

) -> Union[torch.Tensor, TensorDict]:

366

"""

367

Convert observations to PyTorch tensors.

368

369

Args:

370

obs: Observations to convert

371

device: Target device for tensors

372

373

Returns:

374

Tensor or dictionary of tensors

375

"""

376

377

def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray:

378

"""

379

Calculate explained variance.

380

381

Args:

382

y_pred: Predicted values

383

y_true: True values

384

385

Returns:

386

Explained variance

387

"""

388

389

def polyak_update(

390

params: Iterable[torch.nn.Parameter],

391

target_params: Iterable[torch.nn.Parameter],

392

tau: float,

393

) -> None:

394

"""

395

Polyak averaging for target network updates.

396

397

Args:

398

params: Source parameters

399

target_params: Target parameters to update

400

tau: Update coefficient (0 = no update, 1 = hard update)

401

"""

402

403

def update_learning_rate(optimizer: torch.optim.Optimizer, learning_rate: float) -> None:

404

"""

405

Update optimizer learning rate.

406

407

Args:

408

optimizer: PyTorch optimizer

409

learning_rate: New learning rate

410

"""

411

412

def safe_mean(arr: List[float]) -> np.ndarray:

413

"""

414

Calculate mean safely, handling empty arrays.

415

416

Args:

417

arr: Array of values

418

419

Returns:

420

Mean value or NaN if array is empty

421

"""

422

```

423

424

### Schedule Functions

425

426

Learning rate and parameter scheduling utilities for adaptive training dynamics and improved convergence behavior.

427

428

```python { .api }

429

def get_schedule_fn(value_schedule: Union[float, str, Schedule]) -> Schedule:

430

"""

431

Convert value to schedule function.

432

433

Args:

434

value_schedule: Constant value, string identifier, or schedule function

435

436

Returns:

437

Schedule function

438

"""

439

440

def get_linear_fn(

441

start: float, end: float, end_fraction: float

442

) -> Callable[[float], float]:

443

"""

444

Create linear schedule function.

445

446

Args:

447

start: Initial value

448

end: Final value

449

end_fraction: Fraction of training when end value is reached

450

451

Returns:

452

Linear schedule function

453

"""

454

455

def constant_fn(val: float) -> Callable[[float], float]:

456

"""

457

Create constant schedule function.

458

459

Args:

460

val: Constant value

461

462

Returns:

463

Constant schedule function

464

"""

465

```

466

467

## Usage Examples

468

469

### Comprehensive Training Setup

470

471

```python

472

import gymnasium as gym

473

from stable_baselines3 import PPO

474

from stable_baselines3.common.callbacks import (

475

EvalCallback, CheckpointCallback, StopTrainingOnRewardThreshold

476

)

477

from stable_baselines3.common.vec_env import DummyVecEnv

478

479

# Create training and evaluation environments

480

train_env = DummyVecEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])

481

eval_env = DummyVecEnv([lambda: gym.make("CartPole-v1")])

482

483

# Set up callbacks

484

eval_callback = EvalCallback(

485

eval_env,

486

best_model_save_path="./logs/best_model/",

487

log_path="./logs/results/",

488

eval_freq=10000,

489

n_eval_episodes=5,

490

deterministic=True,

491

render=False

492

)

493

494

checkpoint_callback = CheckpointCallback(

495

save_freq=50000 // 4, # counted in callback calls; with the 4 training envs this saves every 50,000 timesteps

496

save_path="./logs/checkpoints/",

497

name_prefix="ppo_cartpole",

498

save_replay_buffer=False,

499

save_vecnormalize=False

500

)

501

502

stop_callback = StopTrainingOnRewardThreshold(

503

reward_threshold=475.0, # CartPole-v1 is considered solved at a mean reward of 475 (195 is the CartPole-v0 threshold)

504

verbose=1

505

)

506

507

# Combine callbacks

508

from stable_baselines3.common.callbacks import CallbackList

509

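# StopTrainingOnRewardThreshold reads its parent EvalCallback's best mean reward, so it
# must be attached to eval_callback rather than added to the list directly (this is
# equivalent to passing callback_on_new_best=stop_callback to the EvalCallback constructor).
eval_callback.callback_on_new_best = stop_callback
stop_callback.parent = eval_callback
callback = CallbackList([eval_callback, checkpoint_callback])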

510

511

# Train with callbacks

512

model = PPO("MlpPolicy", train_env, verbose=1)

513

model.learn(total_timesteps=100000, callback=callback)

514

```

515

516

### Action Noise for Continuous Control

517

518

```python

519

import numpy as np

520

from stable_baselines3 import TD3

521

from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise

522

523

# Gaussian noise

524

n_actions = env.action_space.shape[-1]

525

action_noise = NormalActionNoise(

526

mean=np.zeros(n_actions),

527

sigma=0.1 * np.ones(n_actions)

528

)

529

530

# Ornstein-Uhlenbeck noise for temporally correlated exploration

531

action_noise = OrnsteinUhlenbeckActionNoise(

532

mean=np.zeros(n_actions),

533

sigma=0.1 * np.ones(n_actions),

534

theta=0.15,

535

dt=1e-2

536

)

537

538

model = TD3(

539

"MlpPolicy",

540

env,

541

action_noise=action_noise,

542

verbose=1

543

)

544

545

model.learn(total_timesteps=100000)

546

```

547

548

### Custom Evaluation and Analysis

549

550

```python

551

from stable_baselines3.common.evaluation import evaluate_policy

552

import matplotlib.pyplot as plt

553

554

# Detailed evaluation

555

episode_rewards, episode_lengths = evaluate_policy(

556

model,

557

eval_env,

558

n_eval_episodes=100,

559

deterministic=True,

560

return_episode_rewards=True

561

)

562

563

# Statistical analysis

564

print(f"Mean reward: {np.mean(episode_rewards):.2f} ± {np.std(episode_rewards):.2f}")

565

print(f"Mean episode length: {np.mean(episode_lengths):.2f}")

566

567

# Plot results

568

plt.figure(figsize=(12, 4))

569

plt.subplot(1, 2, 1)

570

plt.hist(episode_rewards, bins=20)

571

plt.xlabel("Episode Reward")

572

plt.ylabel("Frequency")

573

574

plt.subplot(1, 2, 2)

575

plt.hist(episode_lengths, bins=20)

576

plt.xlabel("Episode Length")

577

plt.ylabel("Frequency")

578

plt.show()

579

```

580

581

### Learning Rate Scheduling

582

583

```python

584

from stable_baselines3.common.utils import get_linear_fn

585

586

# Linear learning rate decay

587

learning_rate = get_linear_fn(3e-4, 1e-5, 1.0)

588

589

model = PPO(

590

"MlpPolicy",

591

env,

592

learning_rate=learning_rate,

593

verbose=1

594

)

595

596

# Custom schedule function

597

def custom_schedule(progress_remaining):

598

"""Custom learning rate schedule."""

599

if progress_remaining > 0.5:

600

return 3e-4

601

else:

602

return 1e-4

603

604

model = PPO(

605

"MlpPolicy",

606

env,

607

learning_rate=custom_schedule,

608

verbose=1

609

)

610

```

611

612

### Custom Callback Creation

613

614

```python

615

from stable_baselines3.common.callbacks import BaseCallback

616

617

class LoggingCallback(BaseCallback):

618

"""Custom callback for additional logging."""

619

620

def __init__(self, verbose=0):

621

super(LoggingCallback, self).__init__(verbose)

622

self.episode_rewards = []

623

624

def _on_step(self) -> bool:

625

# Access training variables

626

if len(self.model.ep_info_buffer) > 0:

627

mean_reward = np.mean([ep['r'] for ep in self.model.ep_info_buffer])

628

self.logger.record("custom/mean_episode_reward", mean_reward)

629

630

return True # Continue training

631

632

# Use custom callback

633

custom_callback = LoggingCallback(verbose=1)

634

model.learn(total_timesteps=100000, callback=custom_callback)

635

```

636

637

## Types

638

639

```python { .api }

640

from typing import Union, Optional, Type, Callable, Dict, Any, List, Tuple

641

import numpy as np

642

import gymnasium as gym

643

from stable_baselines3.common.callbacks import BaseCallback, EventCallback, CallbackList

644

from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback

645

from stable_baselines3.common.noise import ActionNoise, NormalActionNoise, OrnsteinUhlenbeckActionNoise

646

from stable_baselines3.common.vec_env import VecEnv

647

from stable_baselines3.common.base_class import BaseAlgorithm

648

```