# Core Utilities

Essential utilities for working with RL environments, models, and hyperparameters. These functions form the foundation of the RL Zoo3 framework, providing core capabilities for algorithm selection, environment creation, model management, and parameter scheduling.

## Core Imports

```python
from rl_zoo3 import ALGOS, create_test_env, get_trained_models, linear_schedule
from rl_zoo3.utils import (
    get_model_path,
    get_saved_hyperparams,
    get_latest_run_id,
    get_wrapper_class,
    get_class_by_name,
    flatten_dict_observations,
    get_callback_list,
)
```

## Capabilities

### Algorithm Dictionary

Central registry mapping algorithm names to their Stable Baselines3 classes, enabling dynamic algorithm selection and instantiation.

```python { .api }
ALGOS: dict[str, type[BaseAlgorithm]]
```

The ALGOS dictionary contains mappings for:

- `"a2c"`: A2C (Advantage Actor-Critic)
- `"ddpg"`: DDPG (Deep Deterministic Policy Gradient)
- `"dqn"`: DQN (Deep Q-Network)
- `"ppo"`: PPO (Proximal Policy Optimization)
- `"sac"`: SAC (Soft Actor-Critic)
- `"td3"`: TD3 (Twin Delayed Deep Deterministic Policy Gradient)
- `"ars"`: ARS (Augmented Random Search)
- `"crossq"`: CrossQ
- `"qrdqn"`: QRDQN (Quantile Regression DQN)
- `"tqc"`: TQC (Truncated Quantile Critics)
- `"trpo"`: TRPO (Trust Region Policy Optimization)
- `"ppo_lstm"`: RecurrentPPO (PPO with LSTM)

Usage example:

```python
from rl_zoo3 import ALGOS
from stable_baselines3.common.env_util import make_vec_env

# Get the PPO algorithm class
ppo_class = ALGOS["ppo"]

# Create environment and model
env = make_vec_env("CartPole-v1", n_envs=1)
model = ppo_class("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000)
```

### Environment Creation

Creates vectorized test environments with proper wrappers, normalization, and configuration for evaluation and testing.

```python { .api }
def create_test_env(
    env_id: str,
    n_envs: int = 1,
    stats_path: Optional[str] = None,
    seed: int = 0,
    log_dir: Optional[str] = None,
    should_render: bool = True,
    hyperparams: Optional[dict[str, Any]] = None,
    env_kwargs: Optional[dict[str, Any]] = None,
    vec_env_cls: Optional[type[VecEnv]] = None,
    vec_env_kwargs: Optional[dict[str, Any]] = None
) -> VecEnv:
    """
    Create a wrapped, monitored VecEnv for testing.

    Parameters:
    - env_id: Environment identifier (e.g., 'CartPole-v1')
    - n_envs: Number of parallel environments
    - stats_path: Path to VecNormalize statistics file
    - seed: Random seed for reproducibility
    - log_dir: Directory for logging environment interactions
    - should_render: Whether to enable rendering (for PyBullet envs)
    - hyperparams: Hyperparameters dict containing env_wrapper settings
    - env_kwargs: Additional keyword arguments for environment creation
    - vec_env_cls: VecEnv class constructor to use
    - vec_env_kwargs: Keyword arguments for VecEnv constructor

    Returns:
    VecEnv: Configured vectorized environment ready for testing
    """
```

Usage example:

```python
import numpy as np
from rl_zoo3 import create_test_env

# Create test environment
env = create_test_env(
    env_id="CartPole-v1",
    n_envs=4,
    seed=42,
    should_render=False,
    hyperparams={},  # empty dict: no saved wrappers or normalization settings
)

# Step with random actions for demo; use model.predict(obs) with a trained model
obs = env.reset()
for _ in range(1000):
    # A VecEnv expects one action per sub-environment
    actions = np.array([env.action_space.sample() for _ in range(env.num_envs)])
    obs, rewards, dones, infos = env.step(actions)
```

### Model Discovery

Discovers and lists trained models from log directories, returning model paths and metadata for loading and evaluation.

```python { .api }
def get_trained_models(log_folder: str) -> dict[str, tuple[str, str]]:
    """
    Get a dictionary of trained models from a log folder.

    Parameters:
    - log_folder: Path to the directory containing trained models

    Returns:
    dict: Dictionary mapping model names to (algo, env_id) tuples
    """
```

```python { .api }
def get_hf_trained_models(
    organization: str = "sb3",
    check_filename: bool = False
) -> dict[str, tuple[str, str]]:
    """
    Get trained models from HuggingFace Hub.

    Parameters:
    - organization: HuggingFace organization name
    - check_filename: Whether to check that each repository contains the expected model file

    Returns:
    dict: Dictionary mapping model names to (algo, env_id) tuples
    """
```

Usage example:

```python
# Import from rl_zoo3.utils, where both functions are defined
from rl_zoo3.utils import get_trained_models, get_hf_trained_models

# Get locally trained models
local_models = get_trained_models("./logs")
print("Local models:", local_models)

# Get models from HuggingFace Hub
hf_models = get_hf_trained_models(organization="sb3")
print("HF models:", list(hf_models.keys())[:5])  # Show first 5
```

### Run Management

Utilities for managing training runs and finding the latest run ID for continued training or evaluation.

```python { .api }
def get_latest_run_id(log_path: str, env_name: str) -> int:
    """
    Get the latest run ID for a given environment.

    Parameters:
    - log_path: Path to the log directory
    - env_name: Environment name

    Returns:
    int: Latest run ID, or 0 if no run is found
    """
```

```python { .api }
def get_model_path(
    exp_id: int,
    folder: str,
    algo: str,
    env_name: str,
    load_best: bool = False,
    load_checkpoint: Optional[str] = None,
    load_last_checkpoint: bool = False
) -> tuple[str, str, str]:
    """
    Get the path to a trained model and related information.

    Parameters:
    - exp_id: Experiment ID (0 for latest)
    - folder: Log folder path
    - algo: Algorithm name
    - env_name: Environment name
    - load_best: Whether to load the best model
    - load_checkpoint: Specific checkpoint to load (e.g., "100000")
    - load_last_checkpoint: Whether to load the last checkpoint

    Returns:
    tuple[str, str, str]: (name_prefix, model_path, log_path)
    """
```
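
The idea of a "latest run ID" can be illustrated with a small standard-library sketch. It assumes run folders are named `<env_name>_<run_id>` inside the algorithm's log directory, as the `./logs/ppo/CartPole-v1_1/` path used elsewhere on this page suggests. `latest_run_id` is a hypothetical stand-in written for this example, not the library function.

```python
import os
import re
import tempfile


def latest_run_id(log_path: str, env_name: str) -> int:
    """Return the highest N among folders named '<env_name>_N', or 0 if none exist."""
    pattern = re.compile(rf"^{re.escape(env_name)}_(\d+)$")
    max_run_id = 0
    if not os.path.isdir(log_path):
        return max_run_id
    for entry in os.listdir(log_path):
        match = pattern.match(entry)
        if match and os.path.isdir(os.path.join(log_path, entry)):
            max_run_id = max(max_run_id, int(match.group(1)))
    return max_run_id


# Demo on a throwaway directory tree (assumed layout: <root>/<algo>/<env_name>_<run_id>)
with tempfile.TemporaryDirectory() as root:
    algo_dir = os.path.join(root, "ppo")
    for run_id in (1, 2, 10):
        os.makedirs(os.path.join(algo_dir, f"CartPole-v1_{run_id}"))
    latest = latest_run_id(algo_dir, "CartPole-v1")
    missing = latest_run_id(algo_dir, "Pendulum-v1")
    # A new training run would typically use the next free ID
    next_run_dir = os.path.join(algo_dir, f"CartPole-v1_{latest + 1}")

print(latest, missing, os.path.basename(next_run_dir))
```

Comparing the numeric suffix as an integer matters here: a plain string sort would rank `CartPole-v1_2` above `CartPole-v1_10`.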

### Hyperparameter Management

Loading and managing saved hyperparameters and normalization statistics from trained models.

```python { .api }
def get_saved_hyperparams(
    stats_path: str,
    norm_reward: bool = False,
    test_mode: bool = False
) -> tuple[dict[str, Any], Optional[str]]:
    """
    Load saved hyperparameters from a stats folder.

    Parameters:
    - stats_path: Path to the folder containing the saved config and normalization statistics
    - norm_reward: Whether reward normalization was used
    - test_mode: Whether in test mode

    Returns:
    tuple: (hyperparams_dict, stats_path); stats_path is None if the folder does not exist
    """
```

Usage example:

```python
from rl_zoo3 import get_saved_hyperparams

# Load hyperparameters from trained model
hyperparams, stats_path = get_saved_hyperparams("./logs/ppo/CartPole-v1_1/")
print("Loaded hyperparams:", hyperparams)
```

### Parameter Scheduling

Linear scheduling functions for hyperparameters like learning rate that need to change during training.

```python { .api }
def linear_schedule(initial_value: Union[float, str]) -> SimpleLinearSchedule:
    """
    Create a linear schedule for a hyperparameter.

    Parameters:
    - initial_value: Initial value (float) or string representation

    Returns:
    SimpleLinearSchedule: Callable schedule object
    """
```

```python { .api }
class SimpleLinearSchedule:
    """
    Linear parameter scheduling class.
    """
    def __init__(self, initial_value: float): ...
    def __call__(self, progress_remaining: float) -> float: ...
```

Usage example:

```python
from rl_zoo3 import linear_schedule, ALGOS
from stable_baselines3.common.env_util import make_vec_env

# Create linear learning rate schedule
lr_schedule = linear_schedule(0.001)

# Use with PPO
env = make_vec_env("CartPole-v1", n_envs=1)
model = ALGOS["ppo"](
    "MlpPolicy",
    env,
    learning_rate=lr_schedule,
    verbose=1
)
```
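
To make the schedule's behavior concrete, here is a minimal dependency-free sketch of a linear decay. It assumes the Stable Baselines3 convention that `progress_remaining` runs from 1.0 at the start of training to 0.0 at the end, and that the scheduled value is simply `progress_remaining * initial_value`. `LinearDecay` is a hypothetical name for this illustration, not the library class.

```python
from typing import Union


class LinearDecay:
    """Hypothetical stand-in for a linear schedule: value = progress_remaining * initial_value."""

    def __init__(self, initial_value: Union[float, str]) -> None:
        # Accept strings such as "0.001" (e.g. values read from a config file)
        self.initial_value = float(initial_value)

    def __call__(self, progress_remaining: float) -> float:
        # progress_remaining goes from 1.0 (start of training) to 0.0 (end)
        return progress_remaining * self.initial_value


schedule = LinearDecay("0.001")
start_lr = schedule(1.0)  # full initial value at the start
mid_lr = schedule(0.5)    # half the initial value halfway through
end_lr = schedule(0.0)    # decays to zero at the end
print(start_lr, mid_lr, end_lr)
```

Because the schedule is a plain callable taking the remaining progress, it can be passed anywhere a constant learning rate would otherwise go.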

### Environment Wrapper Utilities

Utilities for extracting and applying environment wrappers from hyperparameter configurations.

```python { .api }
def get_wrapper_class(
    hyperparams: dict[str, Any],
    key: str = "env_wrapper"
) -> Optional[Callable[[gym.Env], gym.Env]]:
    """
    Get one or more Gym environment wrapper classes from hyperparams.

    Parameters:
    - hyperparams: Hyperparameters dictionary
    - key: Key in hyperparams containing wrapper specification

    Returns:
    Optional wrapper class or wrapper chain function
    """
```

```python { .api }
def get_class_by_name(name: str) -> type:
    """
    Dynamically import a class by its name.

    Parameters:
    - name: Full class name (e.g., 'stable_baselines3.PPO')

    Returns:
    type: The imported class
    """
```
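
Importing a class from its dotted name is a general Python pattern. The sketch below shows one way to do it with `importlib`; `load_class` is a hypothetical helper written for this example and is not guaranteed to match the library's implementation.

```python
import importlib


def load_class(full_name: str) -> type:
    """Import 'package.module.ClassName' and return the class object."""
    # Split on the last dot: everything before it is the module path
    module_name, _, class_name = full_name.rpartition(".")
    if not module_name:
        raise ValueError(f"Expected a dotted path, got {full_name!r}")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# Standard-library class used so the example runs without extra dependencies
ordered_dict_cls = load_class("collections.OrderedDict")
instance = ordered_dict_cls(a=1)
print(ordered_dict_cls.__name__, dict(instance))
```

This is what lets wrappers and callbacks be named as plain strings in a configuration file and resolved only when they are needed.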

```python { .api }
def flatten_dict_observations(env: gym.Env) -> gym.Env:
    """
    Flatten dictionary observation spaces.

    Parameters:
    - env: Environment with Dict observation space

    Returns:
    gym.Env: Environment with flattened observation space
    """
```

### Callback Management

Utilities for creating and managing training callbacks from hyperparameter configurations.

```python { .api }
def get_callback_list(hyperparams: dict[str, Any]) -> list[BaseCallback]:
    """
    Get callback list from hyperparams.

    Parameters:
    - hyperparams: Hyperparameters dictionary containing callback specifications

    Returns:
    list[BaseCallback]: List of configured callbacks
    """
```

Usage example:

```python
from rl_zoo3.utils import get_callback_list

# Hyperparams with callbacks: keyword arguments are given as a
# single-key dict mapping the class path to its kwargs
hyperparams = {
    "callback": [
        {
            "stable_baselines3.common.callbacks.CheckpointCallback": {
                "save_freq": 1000,
                "save_path": "./checkpoints/",
            }
        }
    ]
}

# Get callback list
callbacks = get_callback_list(hyperparams)
print(f"Created {len(callbacks)} callbacks")
```

### Utility Classes

```python { .api }
class StoreDict(argparse.Action):
    """
    Argparse action for storing dictionary parameters.
    Converts key:value pairs to dictionary entries.
    """
    def __call__(self, parser, namespace, values, option_string=None): ...
```

Usage example:

```python
import argparse
from rl_zoo3.utils import StoreDict

parser = argparse.ArgumentParser()
parser.add_argument("--env-kwargs", type=str, nargs="+", action=StoreDict)
# Pairs use the key:value form; each value is evaluated as a Python
# expression, so string values need their own quotes
args = parser.parse_args(["--env-kwargs", "render_mode:'human'", "max_steps:1000"])
print(args.env_kwargs)  # {'render_mode': 'human', 'max_steps': 1000}
```
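
A stricter variant of this idea can be sketched with the standard library alone. `KeyValueAction` below is a hypothetical action written for this example, not part of RL Zoo3. It assumes a `:` separator between key and value and parses values with `ast.literal_eval`, so only Python literals are interpreted and anything else is kept as a plain string.

```python
import argparse
import ast


class KeyValueAction(argparse.Action):
    """Hypothetical action: turn 'key:value' tokens into a dict of parsed literals."""

    def __call__(self, parser, namespace, values, option_string=None):
        result = {}
        for token in values:
            key, sep, raw_value = token.partition(":")
            if not sep:
                parser.error(f"expected key:value, got {token!r}")
            try:
                # Parse numbers, booleans, lists, quoted strings, ...
                result[key] = ast.literal_eval(raw_value)
            except (ValueError, SyntaxError):
                # Fall back to the raw string for unquoted text
                result[key] = raw_value
        setattr(namespace, self.dest, result)


parser = argparse.ArgumentParser()
parser.add_argument("--env-kwargs", nargs="+", action=KeyValueAction, default={})
args = parser.parse_args(
    ["--env-kwargs", "render_mode:human", "max_steps:1000", "scale:0.5"]
)
print(args.env_kwargs)
```

Using `ast.literal_eval` rather than `eval` trades some flexibility (no arbitrary expressions) for safety when arguments come from untrusted input.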