
# Backend Configuration

Backend configuration utilities for managing numerical precision, data formats, random seeds, and cross-backend compatibility settings for the JAX, TensorFlow, PyTorch, and OpenVINO backends.

## Capabilities

### Backend Information

Functions to query the current backend configuration and capabilities.

```python { .api }
def backend():
    """
    Get the name of the current backend.

    Returns:
        str: Backend name ('jax', 'tensorflow', 'torch', 'numpy', or 'openvino')
    """

# Available in keras.distribution
def list_devices(device_type=None):
    """
    List available compute devices.

    Args:
        device_type (str, optional): Filter by device type ('cpu', 'gpu', 'tpu')

    Returns:
        list: Available devices
    """
```
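As a minimal sketch, a script that only supports some backends can check `backend()` at startup. In Keras 3 the backend is chosen before `keras` is first imported (through the `KERAS_BACKEND` environment variable or `~/.keras/keras.json`), so the check can report the choice but not change it. The helper `require_backend` is hypothetical, not part of Keras.

```python
from keras import backend


def require_backend(*allowed):
    """Hypothetical helper: fail fast if the active backend is not supported.

    The backend is fixed when keras is first imported, so this can only
    report the choice, not change it.
    """
    name = backend.backend()
    if name not in allowed:
        raise RuntimeError(
            f"Supported backends are {sorted(allowed)}, but Keras is using {name!r}"
        )
    return name


active = require_backend("jax", "tensorflow", "torch", "numpy", "openvino")
print(f"Running on the {active} backend")
```

A script that only works with JAX would call `require_backend("jax")` near the top instead.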

### Numerical Precision Configuration

Settings for controlling numerical precision and floating-point behavior.

```python { .api }
def floatx():
    """
    Get the default floating-point type.

    Returns:
        str: Default float type ('float16', 'float32', or 'float64')
    """

def set_floatx(dtype):
    """
    Set the default floating-point type.

    Args:
        dtype (str): Float type to use ('float16', 'float32', or 'float64')
    """

def epsilon():
    """
    Get the numerical epsilon value.

    Returns:
        float: Small constant for numerical stability
    """

def set_epsilon(value):
    """
    Set the numerical epsilon value.

    Args:
        value (float): Small constant for numerical stability
    """
```
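`set_floatx` and `set_epsilon` change process-wide defaults, so code that needs different values only temporarily should restore the previous ones. A minimal sketch with a hypothetical `precision_settings` context manager, built only on the four functions above:

```python
import contextlib

from keras import backend


@contextlib.contextmanager
def precision_settings(dtype=None, epsilon=None):
    """Hypothetical helper: temporarily override floatx and/or epsilon.

    Both setters change process-wide defaults, so the previous values are
    restored in a finally block even if the body raises.
    """
    old_dtype, old_epsilon = backend.floatx(), backend.epsilon()
    try:
        if dtype is not None:
            backend.set_floatx(dtype)
        if epsilon is not None:
            backend.set_epsilon(epsilon)
        yield
    finally:
        backend.set_floatx(old_dtype)
        backend.set_epsilon(old_epsilon)


with precision_settings(dtype="float16", epsilon=1e-4):
    print(backend.floatx(), backend.epsilon())  # float16 0.0001

print(backend.floatx(), backend.epsilon())  # the previous values again
```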

### Data Format Configuration

Settings for controlling data layout and format conventions.

```python { .api }
def image_data_format():
    """
    Get the default image data format.

    Returns:
        str: Data format ('channels_last' or 'channels_first')
    """

def set_image_data_format(data_format):
    """
    Set the default image data format.

    Args:
        data_format (str): Format to use ('channels_last' or 'channels_first')
    """
```

### Session and State Management

Functions for managing backend sessions and clearing state.

```python { .api }
def clear_session():
    """
    Clear backend session and free memory.

    Resets the global state Keras keeps (such as the counters behind
    auto-generated layer names), clears backend-specific state (for example,
    the default graph on TensorFlow), and triggers garbage collection to
    free up memory.
    """

def get_uid(prefix=''):
    """
    Generate a unique identifier for naming.

    Each call with the same prefix increments a per-prefix counter.

    Args:
        prefix (str): Prefix for the identifier

    Returns:
        int: Unique integer ID for the given prefix (1, 2, 3, ...)
    """
```
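The sketch below shows how the two functions interact, assuming the Keras 3 behavior where `get_uid` counters live in the global state that `clear_session()` resets. The prefix `"block"` is arbitrary.

```python
import keras
from keras import backend

backend.clear_session()  # start from a clean global state

first = backend.get_uid("block")
second = backend.get_uid("block")
print(first, second)  # 1 2

# Auto-generated layer names are also derived from global counters,
# so two unnamed layers of the same class get distinct names
names = [keras.layers.Dense(4).name for _ in range(2)]
print(names)

# clear_session() resets the counters, so numbering starts over
backend.clear_session()
after_reset = backend.get_uid("block")
print(after_reset)  # 1
```

This is why calling `clear_session()` before building a new model in a loop keeps layer names stable from one iteration to the next.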

### Random Seed Configuration

Functions for controlling random number generation across backends.

```python { .api }
# Available in keras.utils
def set_random_seed(seed):
    """
    Set global random seed for reproducibility.

    This sets the random seed for the current backend, NumPy,
    and Python's random module so that results are reproducible.
    Some GPU kernels (for example, certain cuDNN operations) can
    remain nondeterministic.

    Args:
        seed (int): Random seed value
    """
```

### Data Type Utilities

Utilities for working with data types across different backends.

```python { .api }
def is_keras_tensor(x):
    """
    Check if object is a Keras tensor.

    Args:
        x: Object to check

    Returns:
        bool: True if x is a Keras tensor
    """

def is_float_dtype(dtype):
    """
    Check if data type is floating point.

    Args:
        dtype (str or dtype): Data type to check

    Returns:
        bool: True if dtype is floating point
    """

def is_int_dtype(dtype):
    """
    Check if data type is integer.

    Args:
        dtype (str or dtype): Data type to check

    Returns:
        bool: True if dtype is integer
    """

def standardize_dtype(dtype):
    """
    Standardize data type string representation.

    Args:
        dtype (str or dtype): Data type to standardize

    Returns:
        str: Standardized dtype string
    """

def result_type(*dtypes):
    """
    Determine result data type from multiple input types.

    Args:
        *dtypes: Input data types

    Returns:
        str: Result data type
    """
```

### Device Management

Functions for device placement and context management.

```python { .api }
# Available as keras.device
def device(device_name):
    """
    Device placement context manager.

    Args:
        device_name (str): Device name ('cpu', 'gpu', 'gpu:0', etc.)

    Returns:
        context manager: Device placement context
    """

# Available as keras.name_scope
def name_scope(name):
    """
    Name scoping context manager for operations.

    Args:
        name (str): Scope name

    Returns:
        context manager: Name scope context
    """
```

### Mixed Precision Configuration

Settings for mixed precision training and inference.

```python { .api }
# Available in keras.mixed_precision
def set_global_policy(policy):
    """
    Set global mixed precision policy.

    Args:
        policy (str or Policy): Policy name or Policy instance
            Common policies: 'mixed_float16', 'mixed_bfloat16', 'float32'
    """

def global_policy():
    """
    Get current global mixed precision policy.

    Returns:
        Policy: Current mixed precision policy
    """
```
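Under a mixed policy, layers compute in the lower precision while keeping their variables in `float32`. A minimal sketch that inspects this, using a hypothetical `dtype_policy` context manager so the global policy is restored afterwards:

```python
import contextlib

import keras
from keras import mixed_precision


@contextlib.contextmanager
def dtype_policy(name):
    """Hypothetical helper: temporarily set the global mixed precision policy."""
    previous = mixed_precision.global_policy()
    mixed_precision.set_global_policy(name)
    try:
        yield mixed_precision.global_policy()
    finally:
        mixed_precision.set_global_policy(previous)


before = mixed_precision.global_policy().name

with dtype_policy("mixed_float16") as policy:
    # Layers pick up the global policy when they are constructed
    layer = keras.layers.Dense(4)
    print(policy.compute_dtype, policy.variable_dtype)  # float16 float32
    print(layer.compute_dtype, layer.variable_dtype)    # float16 float32

print(mixed_precision.global_policy().name)  # the previous policy again
```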

## Usage Examples

### Basic Backend Configuration

```python
import keras
from keras import backend

# Check current backend
print(f"Current backend: {backend.backend()}")

# Configure floating point precision
backend.set_floatx('float32')
print(f"Default float type: {backend.floatx()}")

# Set image data format
backend.set_image_data_format('channels_last')
print(f"Image data format: {backend.image_data_format()}")

# Set random seed for reproducibility
keras.utils.set_random_seed(42)

# Clear session to free memory
backend.clear_session()
```

### Device Placement

```python
import keras

# Use CPU for specific operations
with keras.device('cpu:0'):
    x = keras.ops.ones((1000, 1000))
    y = keras.ops.matmul(x, x)

# Use the first GPU (requires a GPU; what happens without one depends on the backend)
with keras.device('gpu:0'):
    model = keras.Sequential([
        keras.layers.Dense(64, activation='relu'),
        keras.layers.Dense(10, activation='softmax')
    ])

    # The model is built on its first call, so its weights are created on this device
    predictions = model(x)
```

### Mixed Precision Training

```python
import keras
from keras import mixed_precision

# Enable mixed precision
mixed_precision.set_global_policy('mixed_float16')

# Build model (layers pick up the global policy automatically)
model = keras.Sequential([
    keras.Input(shape=(784,)),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(10, activation='softmax', dtype='float32')  # Keep output in float32
])

# Use LossScaleOptimizer for stable float16 training
optimizer = keras.optimizers.Adam()
optimizer = keras.optimizers.LossScaleOptimizer(optimizer)

model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Train normally; x_train, y_train, x_val, and y_val are placeholders for your data
model.fit(x_train, y_train, epochs=10, validation_data=(x_val, y_val))
```

### Backend-Specific Configuration

```python
import keras
from keras import backend

# Configuration based on backend
if backend.backend() == 'tensorflow':
    # TensorFlow-specific settings: allocate GPU memory on demand
    # (must run before any GPU has been initialized)
    import tensorflow as tf
    for gpu in tf.config.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(gpu, True)

elif backend.backend() == 'jax':
    # JAX-specific settings: enable 64-bit types
    import jax
    jax.config.update('jax_enable_x64', True)

elif backend.backend() == 'torch':
    # PyTorch-specific settings: let cuDNN benchmark convolution algorithms
    import torch
    torch.backends.cudnn.benchmark = True

# Universal settings
backend.set_floatx('float32')
backend.set_image_data_format('channels_last')
keras.utils.set_random_seed(42)
```

### Memory Management

```python
import gc

import keras
from keras import backend

def train_with_memory_management(build_model, train_data, val_data, epochs=10):
    """Build and train a fresh model with explicit memory management.

    build_model is a zero-argument callable that returns a compiled model.
    """

    # Clear existing session state before building the new model
    backend.clear_session()

    # Build and train the model
    model = build_model()
    history = model.fit(
        train_data,
        validation_data=val_data,
        epochs=epochs
    )

    # Clear session and force garbage collection
    backend.clear_session()
    gc.collect()

    return history

# Usage (the layers, train_dataset, and val_dataset are placeholders)
def build_model():
    model = keras.Sequential([...])
    model.compile(optimizer='adam', loss='mse')
    return model

history = train_with_memory_management(build_model, train_dataset, val_dataset)
```

### Reproducible Training Setup

```python
import os

import keras
from keras import backend

def setup_reproducible_training(seed=42):
    """Set up reproducible training environment."""

    # Clear any existing state before seeding
    backend.clear_session()

    # Seeds Python's random module, NumPy, and the active backend in one call
    keras.utils.set_random_seed(seed)

    # Only affects subprocesses started from here on; to fix hashing in the
    # current interpreter, export PYTHONHASHSEED before launching Python
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Backend-specific reproducibility
    if backend.backend() == 'tensorflow':
        import tensorflow as tf
        tf.config.experimental.enable_op_determinism()

    print(f"Reproducible training setup complete with seed {seed}")

# Setup reproducible environment
setup_reproducible_training(42)

# Now build and train the model (the layers, x_train, and y_train are placeholders)
model = keras.Sequential([...])
model.compile(optimizer='adam', loss='mse')
model.fit(x_train, y_train, epochs=10)
```