# Training Callbacks

Custom callbacks for training monitoring, evaluation, hyperparameter optimization, and logging. These callbacks extend Stable Baselines3's callback system with specialized functionality for RL Zoo3's training workflows.

## Core Imports

```python
from rl_zoo3.callbacks import (
    TrialEvalCallback,
    SaveVecNormalizeCallback,
    ParallelTrainCallback,
    RawStatisticsCallback,
)
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
import optuna
```

## Capabilities

### Trial Evaluation Callback

Specialized callback for Optuna hyperparameter optimization trials, providing evaluation and pruning functionality during optimization runs.

```python { .api }
class TrialEvalCallback(EvalCallback):
    """
    Callback used for evaluating and reporting a trial during hyperparameter optimization.
    Extends EvalCallback with Optuna trial integration.
    """

    def __init__(
        self,
        eval_env: VecEnv,
        trial: optuna.Trial,
        n_eval_episodes: int = 5,
        eval_freq: int = 10000,
        deterministic: bool = True,
        verbose: int = 0,
        best_model_save_path: Optional[str] = None,
        log_path: Optional[str] = None,
    ) -> None:
        """
        Initialize TrialEvalCallback.

        Parameters:
        - eval_env: Vectorized evaluation environment
        - trial: Optuna trial object for reporting results
        - n_eval_episodes: Number of episodes for each evaluation
        - eval_freq: Frequency of evaluation (in timesteps)
        - deterministic: Whether to use deterministic actions during evaluation
        - verbose: Verbosity level
        - best_model_save_path: Path to save best model
        - log_path: Path for evaluation logs
        """

    def _on_step(self) -> bool:
        """
        Called at each training step. Performs evaluation and reports to Optuna.

        Returns:
            bool: Whether training should continue
        """
```

Usage example:

```python
import optuna
from rl_zoo3.callbacks import TrialEvalCallback
from rl_zoo3 import create_test_env

def objective(trial):
    # Create evaluation environment
    eval_env = create_test_env("CartPole-v1", n_envs=1)

    # Create callback
    eval_callback = TrialEvalCallback(
        eval_env=eval_env,
        trial=trial,
        n_eval_episodes=10,
        eval_freq=1000,
        deterministic=True,
        verbose=0,
    )

    # Use callback in model training (simplified)
    # model.learn(total_timesteps=10000, callback=eval_callback)

    # Propagate pruning decisions made inside the callback to Optuna
    if eval_callback.is_pruned:
        raise optuna.TrialPruned()

    return eval_callback.best_mean_reward

# Run optimization
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=10)
```

### VecNormalize Saving Callback

Callback for automatically saving VecNormalize statistics during training, ensuring normalization parameters are preserved with the model.

```python { .api }
class SaveVecNormalizeCallback(BaseCallback):
    """
    Callback for saving VecNormalize statistics at regular intervals.
    """

    def __init__(
        self,
        save_freq: int,
        save_path: str,
        name_prefix: str = "vecnormalize",
        verbose: int = 0,
    ):
        """
        Initialize SaveVecNormalizeCallback.

        Parameters:
        - save_freq: Frequency of saving (in timesteps)
        - save_path: Directory path for saving statistics
        - name_prefix: Prefix for saved files
        - verbose: Verbosity level
        """

    def _on_step(self) -> bool:
        """
        Called at each training step. Saves VecNormalize stats at specified frequency.

        Returns:
            bool: Always True (never stops training)
        """
```

Usage example:

```python
from rl_zoo3.callbacks import SaveVecNormalizeCallback
from stable_baselines3.common.vec_env import VecNormalize
from stable_baselines3 import PPO

# Create normalized environment (base_env: an existing vectorized environment)
env = VecNormalize(base_env, norm_obs=True, norm_reward=True)

# Create callback
save_callback = SaveVecNormalizeCallback(
    save_freq=5000,
    save_path="./logs/",
    name_prefix="ppo_vecnormalize",
    verbose=1,
)

# Train with callback
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=20000, callback=save_callback)
```

### Parallel Training Callback

Callback for coordinating parallel training processes, handling synchronization and communication between multiple training instances.

```python { .api }
class ParallelTrainCallback(BaseCallback):
    """
    Callback for parallel training coordination.
    Handles synchronization between multiple training processes.
    """

    def __init__(self, verbose: int = 0):
        """
        Initialize ParallelTrainCallback.

        Parameters:
        - verbose: Verbosity level
        """

    def _on_training_start(self) -> None:
        """
        Called when training starts. Sets up parallel coordination.
        """

    def _on_step(self) -> bool:
        """
        Called at each training step. Handles parallel synchronization.

        Returns:
            bool: Whether training should continue
        """

    def _on_training_end(self) -> None:
        """
        Called when training ends. Cleans up parallel resources.
        """
```
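
All of the callbacks above follow the hook order defined by Stable Baselines3's `BaseCallback` contract: `_on_training_start` fires once, `_on_step` fires every step (returning `False` stops training), and `_on_training_end` fires at shutdown. A minimal pure-Python sketch of that dispatch, shown here purely as an illustration of the contract rather than RL Zoo3's actual implementation:

```python
class MiniCallback:
    """Illustrative stand-in for BaseCallback's hook contract."""

    def __init__(self):
        self.n_calls = 0
        self.events = []

    def _on_training_start(self):
        self.events.append("start")

    def _on_step(self):
        self.n_calls += 1
        self.events.append("step")
        # Returning False asks the training loop to stop early
        return self.n_calls < 3

    def _on_training_end(self):
        self.events.append("end")


def train(total_timesteps, callback):
    """Minimal training loop showing when each hook fires."""
    callback._on_training_start()
    for _ in range(total_timesteps):
        if not callback._on_step():
            break
    callback._on_training_end()


cb = MiniCallback()
train(10, cb)
print(cb.events)  # "start", three "step" entries (stopped early), then "end"
```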

### Raw Statistics Callback

Callback for logging detailed training statistics and metrics, providing comprehensive monitoring of training progress.

```python { .api }
class RawStatisticsCallback(BaseCallback):
    """
    Callback for logging raw training statistics.
    Provides detailed metrics beyond standard Stable Baselines3 logging.
    """

    def __init__(
        self,
        verbose: int = 0,
        log_freq: int = 1000,
    ):
        """
        Initialize RawStatisticsCallback.

        Parameters:
        - verbose: Verbosity level
        - log_freq: Frequency of detailed logging (in timesteps)
        """

    def _on_step(self) -> bool:
        """
        Called at each training step. Logs detailed statistics at specified frequency.

        Returns:
            bool: Always True (never stops training)
        """

    def _on_rollout_end(self) -> None:
        """
        Called at the end of each rollout. Logs rollout statistics.
        """
```

Usage example:

```python
from rl_zoo3.callbacks import RawStatisticsCallback
from stable_baselines3 import PPO

# Create callback
stats_callback = RawStatisticsCallback(
    verbose=1,
    log_freq=1000,
)

# Train with detailed statistics logging (env: an existing environment)
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=50000, callback=stats_callback)
```
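
Frequency parameters such as `log_freq` here and `save_freq` above gate work on the callback's step counter in the usual Stable Baselines3 way: the callback acts whenever `n_calls` is a multiple of the frequency. A small self-contained sketch of that gating:

```python
def fired_steps(total_steps, freq):
    """Return the step indices at which a frequency-gated callback does work."""
    fired = []
    for n_calls in range(1, total_steps + 1):
        if n_calls % freq == 0:
            fired.append(n_calls)
    return fired

print(fired_steps(5000, 1000))  # [1000, 2000, 3000, 4000, 5000]
```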

## Callback Combinations

You can combine multiple callbacks for comprehensive training monitoring:

```python
from rl_zoo3.callbacks import SaveVecNormalizeCallback, RawStatisticsCallback
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback, CallbackList

# Create multiple callbacks
checkpoint_callback = CheckpointCallback(
    save_freq=10000,
    save_path="./checkpoints/",
    name_prefix="ppo_model",
)

eval_callback = EvalCallback(
    eval_env=eval_env,
    best_model_save_path="./best_models/",
    log_path="./eval_logs/",
    eval_freq=5000,
    n_eval_episodes=10,
    deterministic=True,
)

save_vec_callback = SaveVecNormalizeCallback(
    save_freq=5000,
    save_path="./vecnormalize/",
    verbose=1,
)

stats_callback = RawStatisticsCallback(
    verbose=1,
    log_freq=1000,
)

# Combine callbacks
callback_list = CallbackList([
    checkpoint_callback,
    eval_callback,
    save_vec_callback,
    stats_callback,
])

# Train with all callbacks
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=100000, callback=callback_list)
```
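
`CallbackList` forwards each event to its children in order, and training continues only while every child's step hook returns `True`. A pure-Python sketch of that semantics (an illustration, not SB3's implementation):

```python
class MiniCallbackList:
    """Continue training only while every child callback agrees to continue."""

    def __init__(self, callbacks):
        self.callbacks = callbacks

    def _on_step(self):
        continue_training = True
        for cb in self.callbacks:
            # Every child is still called, even after one has voted to stop
            continue_training = cb._on_step() and continue_training
        return continue_training


calls = []

class Recorder:
    """Toy callback that records its name and returns a fixed decision."""

    def __init__(self, name, result=True):
        self.name, self.result = name, result

    def _on_step(self):
        calls.append(self.name)
        return self.result


clist = MiniCallbackList([
    Recorder("checkpoint"),
    Recorder("eval", result=False),
    Recorder("stats"),
])
print(clist._on_step())  # False: the eval callback requested a stop
print(calls)             # all three callbacks were still invoked
```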

## Integration with ExperimentManager

The ExperimentManager automatically creates and configures appropriate callbacks based on the training setup:

```python
from rl_zoo3.exp_manager import ExperimentManager
import argparse

# ExperimentManager will automatically create callbacks based on configuration
args = argparse.Namespace(
    algo="ppo",
    env="CartPole-v1",
    n_timesteps=50000,
    eval_freq=5000,   # Will create EvalCallback
    save_freq=10000,  # Will create CheckpointCallback
    verbose=1,
)

exp_manager = ExperimentManager(
    args=args,
    algo="ppo",
    env_id="CartPole-v1",
    log_folder="./logs",
    n_timesteps=50000,
    eval_freq=5000,
    save_freq=10000,
)

# Callbacks are automatically created and configured
model = exp_manager.setup_experiment()
exp_manager.learn(model)  # Uses automatically created callbacks
```

## Custom Callback Integration

You can also create custom callbacks that work with RL Zoo3's training system:

```python
from stable_baselines3.common.callbacks import BaseCallback

class CustomCallback(BaseCallback):
    """
    Custom callback for specific training needs.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)
        self.custom_metric = 0

    def _on_training_start(self) -> None:
        print("Custom training logic started")

    def _on_step(self) -> bool:
        # Custom logic here
        self.custom_metric += 1

        # Log custom metrics
        if self.n_calls % 1000 == 0:
            self.logger.record("custom/metric", self.custom_metric)

        return True  # Continue training

# Use custom callback with ExperimentManager
# You would typically integrate this through hyperparams configuration
# or by modifying the ExperimentManager's create_callbacks method
```