# Backend Configuration

Backend configuration utilities for managing numerical precision, data formats, random seeds, and cross-backend compatibility settings for JAX, TensorFlow, PyTorch, and OpenVINO backends.

## Capabilities

### Backend Information

Functions to query current backend configuration and capabilities.

```python { .api }
def backend():
    """
    Get the name of the current backend.

    Returns:
        str: Backend name ('jax', 'tensorflow', 'torch', or 'openvino')
    """

def list_devices(device_type=None):
    """
    List available compute devices.

    Args:
        device_type (str, optional): Filter by device type ('cpu', 'gpu', 'tpu')

    Returns:
        list: Available devices
    """
```
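
The backend reported by `backend()` is chosen when `keras` is first imported. Keras 3 reads the choice from the `KERAS_BACKEND` environment variable, so it has to be set before the import. The following is a minimal sketch of that step; `select_backend` and `VALID_BACKENDS` are hypothetical names introduced for illustration and are not part of Keras.

```python
import os
import sys

# Backend names taken from the backend() docstring above.
VALID_BACKENDS = ("jax", "tensorflow", "torch", "openvino")

def select_backend(name):
    """Hypothetical helper: request a backend through KERAS_BACKEND.

    Returns True if the setting can still take effect, that is, if
    keras has not been imported in this process yet.
    """
    if name not in VALID_BACKENDS:
        raise ValueError(f"Unknown backend {name!r}; expected one of {VALID_BACKENDS}")
    os.environ["KERAS_BACKEND"] = name
    return "keras" not in sys.modules

takes_effect = select_backend("jax")
print(f"KERAS_BACKEND={os.environ['KERAS_BACKEND']}, takes effect: {takes_effect}")
# A later `import keras` would then start with the requested backend.
```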

### Numerical Precision Configuration

Settings for controlling numerical precision and floating-point behavior.

```python { .api }
def floatx():
    """
    Get the default floating-point type.

    Returns:
        str: Default float type ('float16', 'float32', or 'float64')
    """

def set_floatx(dtype):
    """
    Set the default floating-point type.

    Args:
        dtype (str): Float type to use ('float16', 'float32', or 'float64')
    """

def epsilon():
    """
    Get the numerical epsilon value.

    Returns:
        float: Small constant for numerical stability
    """

def set_epsilon(value):
    """
    Set the numerical epsilon value.

    Args:
        value (float): Small constant for numerical stability
    """
```
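
To illustrate what "small constant for numerical stability" means in practice, here is a sketch in plain Python of two common uses of an epsilon: clamping before a logarithm and padding a denominator. The value `1e-7` is assumed to match the default returned by `epsilon()`; `safe_log` and `safe_divide` are hypothetical helpers, not Keras functions.

```python
import math

EPSILON = 1e-7  # assumption: mirrors the default value of epsilon()

def safe_log(x, eps=EPSILON):
    """Clamp x to at least eps so that log(0) is never evaluated."""
    return math.log(max(x, eps))

def safe_divide(numerator, denominator, eps=EPSILON):
    """Pad the denominator with eps so a zero denominator does not raise."""
    return numerator / (denominator + eps)

print(safe_log(0.0))          # finite, where math.log(0.0) would raise ValueError
print(safe_divide(1.0, 0.0))  # large but finite, where 1.0 / 0.0 would raise
```

In real code the constant would come from `backend.epsilon()` rather than a hard-coded value, so that a call to `set_epsilon` is respected everywhere.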

### Data Format Configuration

Settings for controlling data layout and format conventions.

```python { .api }
def image_data_format():
    """
    Get the default image data format.

    Returns:
        str: Data format ('channels_last' or 'channels_first')
    """

def set_image_data_format(data_format):
    """
    Set the default image data format.

    Args:
        data_format (str): Format to use ('channels_last' or 'channels_first')
    """
```
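
The two formats differ only in where the channel axis sits. The sketch below shows how a per-image shape could be derived from the format string; `image_shape` is a hypothetical helper written for this illustration.

```python
def image_shape(height, width, channels, data_format="channels_last"):
    """Return the per-image shape (without the batch axis) for a data format."""
    if data_format == "channels_last":
        return (height, width, channels)
    if data_format == "channels_first":
        return (channels, height, width)
    raise ValueError(f"Unsupported data_format: {data_format!r}")

print(image_shape(224, 224, 3))                    # (224, 224, 3)
print(image_shape(224, 224, 3, "channels_first"))  # (3, 224, 224)
```

Passing the result of `image_data_format()` as `data_format` keeps model input shapes consistent with the global setting.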

### Session and State Management

Functions for managing backend sessions and clearing state.

```python { .api }
def clear_session():
    """
    Clear backend session and free memory.

    This function clears any cached state, resets the default graph,
    and triggers garbage collection to free up memory.
    """

def get_uid(prefix=''):
    """
    Generate a unique identifier for naming.

    Each prefix has its own counter, which is incremented on every call.

    Args:
        prefix (str): Prefix for the identifier

    Returns:
        int: Unique integer ID for the given prefix
    """
```
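
The relationship between these two functions can be pictured as a set of per-prefix counters that a session reset clears. The following is a conceptual sketch only, not Keras's implementation; it assumes that identifiers are integer counters kept per prefix and that clearing the session resets them. All names ending in `_sketch` are hypothetical.

```python
from collections import defaultdict

# Hypothetical module-level state: one counter per prefix.
_uid_counters = defaultdict(int)

def get_uid_sketch(prefix=""):
    """Increment and return the counter associated with `prefix`."""
    _uid_counters[prefix] += 1
    return _uid_counters[prefix]

def clear_session_sketch():
    """Forget all counters, so numbering starts again from 1."""
    _uid_counters.clear()

first = get_uid_sketch("dense")         # 1
second = get_uid_sketch("dense")        # 2
other = get_uid_sketch("conv")          # 1, separate counter
clear_session_sketch()
after_clear = get_uid_sketch("dense")   # 1 again after the reset

print(f"dense_{first}", f"dense_{second}", f"conv_{other}", f"dense_{after_clear}")
```

This is why default layer names such as `dense_1` can start over after a session is cleared.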

### Random Seed Configuration

Functions for controlling random number generation across backends.

```python { .api }
def set_random_seed(seed):
    """
    Set global random seed for reproducibility.

    This sets the random seed for the current backend, NumPy,
    and Python's random module to ensure reproducible results.

    Args:
        seed (int): Random seed value
    """
```
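
The effect of seeding can be seen with Python's `random` module alone, which is one of the generators the docstring above says is seeded. This sketch uses only the standard library; `draw` is a hypothetical helper.

```python
import random

def draw(seed, n=3):
    """Seed Python's global generator, then draw n floats from it."""
    random.seed(seed)
    return [random.random() for _ in range(n)]

# The same seed reproduces the same sequence; a different seed does not.
print(draw(42) == draw(42))
print(draw(42) == draw(43))
```

The same principle applies to weight initialization and data shuffling: with a fixed seed, two runs start from the same random draws.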

### Data Type Utilities

Utilities for working with data types across different backends.

```python { .api }
def is_keras_tensor(x):
    """
    Check if object is a Keras tensor.

    Args:
        x: Object to check

    Returns:
        bool: True if x is a Keras tensor
    """

def is_float_dtype(dtype):
    """
    Check if data type is floating point.

    Args:
        dtype (str or dtype): Data type to check

    Returns:
        bool: True if dtype is floating point
    """

def is_int_dtype(dtype):
    """
    Check if data type is integer.

    Args:
        dtype (str or dtype): Data type to check

    Returns:
        bool: True if dtype is integer
    """

def standardize_dtype(dtype):
    """
    Standardize data type string representation.

    Args:
        dtype (str or dtype): Data type to standardize

    Returns:
        str: Standardized dtype string
    """

def result_type(*dtypes):
    """
    Determine result data type from multiple input types.

    Args:
        *dtypes: Input data types

    Returns:
        str: Result data type
    """
```

### Device Management

Functions for device placement and context management.

```python { .api }
def device(device_name):
    """
    Device placement context manager.

    Args:
        device_name (str): Device name ('cpu', 'gpu', 'gpu:0', etc.)

    Returns:
        context manager: Device placement context
    """

def name_scope(name):
    """
    Name scoping context manager for operations.

    Args:
        name (str): Scope name

    Returns:
        context manager: Name scope context
    """
```
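
Because a device name such as `'gpu:0'` is only valid when that device exists, code often picks a name from the devices actually present. Below is a small sketch of such a selection; `pick_device` is a hypothetical helper, and it assumes device names of the form `'type:index'`, as in the `device_name` examples above.

```python
def pick_device(available, preferred=("gpu", "tpu")):
    """Return the first available accelerator name, falling back to 'cpu'.

    `available` is a list of device names such as ['cpu:0', 'gpu:0'].
    """
    for device_type in preferred:
        for name in available:
            if name.lower().startswith(device_type + ":"):
                return name
    return "cpu"

print(pick_device(["cpu:0", "gpu:0", "gpu:1"]))  # gpu:0
print(pick_device(["cpu:0"]))                    # cpu
```

The returned name could then be passed to the `device` context manager, with the list of names coming from `list_devices()`.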

### Mixed Precision Configuration

Settings for mixed precision training and inference.

```python { .api }
# Available in keras.mixed_precision
def set_global_policy(policy):
    """
    Set global mixed precision policy.

    Args:
        policy (str or Policy): Policy name or Policy instance
            Common policies: 'mixed_float16', 'mixed_bfloat16', 'float32'
    """

def global_policy():
    """
    Get current global mixed precision policy.

    Returns:
        Policy: Current mixed precision policy
    """
```

## Usage Examples

### Basic Backend Configuration

```python
import keras
from keras import backend

# Check current backend
print(f"Current backend: {backend.backend()}")

# Configure floating point precision
backend.set_floatx('float32')
print(f"Default float type: {backend.floatx()}")

# Set image data format
backend.set_image_data_format('channels_last')
print(f"Image data format: {backend.image_data_format()}")

# Set random seed for reproducibility
keras.utils.set_random_seed(42)

# Clear session to free memory
backend.clear_session()
```

### Device Placement

```python
import keras

# Use CPU for specific operations
# (the device context manager is exposed as keras.device)
with keras.device('cpu'):
    x = keras.ops.ones((1000, 1000))
    y = keras.ops.matmul(x, x)

# Use the first GPU (requires a GPU visible to the backend)
with keras.device('gpu:0'):
    model = keras.Sequential([
        keras.layers.Dense(64, activation='relu'),
        keras.layers.Dense(10, activation='softmax')
    ])

    predictions = model(x)
```

### Mixed Precision Training

```python
import keras
from keras import mixed_precision

# Enable mixed precision
mixed_precision.set_global_policy('mixed_float16')

# Build model (will use mixed precision automatically)
model = keras.Sequential([
    keras.Input(shape=(784,)),  # Declare the input shape with an explicit Input
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(10, activation='softmax', dtype='float32')  # Keep output in float32
])

# Use LossScaleOptimizer for stable training
optimizer = keras.optimizers.Adam()
optimizer = keras.optimizers.LossScaleOptimizer(optimizer)

model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Train normally - mixed precision is handled automatically
# (x_train, y_train, x_val, and y_val are your own datasets)
model.fit(x_train, y_train, epochs=10, validation_data=(x_val, y_val))
```
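
The reason loss scaling helps can be shown without any framework: very small gradient values round to zero in float16, but survive if they are multiplied by a scale factor first and divided by it afterwards in higher precision. The sketch below uses the standard library's half-precision `struct` format; the gradient value and the scale `2.0 ** 15` are illustrative assumptions, and `to_float16` is a hypothetical helper.

```python
import struct

def to_float16(value):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", value))[0]

gradient = 1e-8          # illustrative small gradient value
loss_scale = 2.0 ** 15   # illustrative scale factor (an assumption)

# Stored directly in float16, the gradient underflows to zero.
unscaled = to_float16(gradient)

# Scaled before the float16 round trip, then unscaled in full precision.
scaled = to_float16(gradient * loss_scale)
recovered = scaled / loss_scale

print(unscaled)   # 0.0
print(recovered)  # approximately 1e-08
```

A loss-scaling optimizer applies the same idea to the loss and gradients during training, which is why it is paired with the `'mixed_float16'` policy.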

### Backend-Specific Configuration

```python
import keras
from keras import backend

# Configuration based on backend
if backend.backend() == 'tensorflow':
    # TensorFlow-specific settings: allocate GPU memory on demand
    import tensorflow as tf
    for gpu in tf.config.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(gpu, True)

elif backend.backend() == 'jax':
    # JAX-specific settings
    import jax
    jax.config.update('jax_enable_x64', True)

elif backend.backend() == 'torch':
    # PyTorch-specific settings
    import torch
    torch.backends.cudnn.benchmark = True

# Universal settings
backend.set_floatx('float32')
backend.set_image_data_format('channels_last')
keras.utils.set_random_seed(42)
```

### Memory Management

```python
import keras
from keras import backend
import gc

def train_with_memory_management(model, train_data, val_data):
    """Train model with explicit memory management."""

    # Clear any existing session state
    backend.clear_session()

    # Train model
    history = model.fit(
        train_data,
        validation_data=val_data,
        epochs=10
    )

    # Clear session and force garbage collection
    backend.clear_session()
    gc.collect()

    return history

# Usage
model = keras.Sequential([...])
history = train_with_memory_management(model, train_dataset, val_dataset)
```

### Reproducible Training Setup

```python
import keras
from keras import backend
import numpy as np
import random
import os

def setup_reproducible_training(seed=42):
    """Set up reproducible training environment."""

    # Set random seeds
    keras.utils.set_random_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # Only affects Python processes started after this point
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Backend-specific reproducibility
    if backend.backend() == 'tensorflow':
        import tensorflow as tf
        tf.config.experimental.enable_op_determinism()

    # Clear any existing state
    backend.clear_session()

    print(f"Reproducible training setup complete with seed {seed}")

# Setup reproducible environment
setup_reproducible_training(42)

# Now build and train model
model = keras.Sequential([...])
model.compile(optimizer='adam', loss='mse')
model.fit(x_train, y_train, epochs=10)
```