Keras: a multi-backend deep learning framework that provides a unified API for building and training neural networks on JAX, TensorFlow, PyTorch, and OpenVINO backends.
—
Backend configuration utilities for managing numerical precision, data formats, random seeds, and cross-backend compatibility settings for JAX, TensorFlow, PyTorch, and OpenVINO backends.
Functions for querying the current backend configuration and capabilities.
def backend():
    """
    Get the name of the current backend.

    Use the result to gate backend-specific code, e.g.
    `if backend() == 'torch': ...`.

    Returns:
        str: Backend name ('jax', 'tensorflow', 'torch', or 'openvino')
    """
def list_devices(device_type=None):
    """
    List available compute devices.

    Args:
        device_type (str, optional): Filter by device type ('cpu', 'gpu',
            'tpu'). NOTE(review): presumably the default None returns
            devices of every type — confirm.

    Returns:
        list: Available devices. NOTE(review): the element type (device
            name strings vs. backend device objects) is not specified
            here — confirm. In Keras 3 this appears to be exposed as
            `keras.distribution.list_devices`; verify the import path.
    """
Settings for controlling numerical precision and floating-point behavior.
def floatx():
    """
    Get the default floating-point type.

    This is the float dtype Keras falls back to when none is specified.

    Returns:
        str: Default float type ('float16', 'float32', or 'float64').
            NOTE(review): Keras 3 also documents 'bfloat16' as a valid
            value and 'float32' as the default — confirm against the
            installed version.
    """
def set_floatx(dtype):
    """
    Set the default floating-point type.

    Args:
        dtype (str): Float type to use ('float16', 'float32', or 'float64').
            NOTE(review): Keras 3 also accepts 'bfloat16' and presumably
            rejects any other value with ValueError — confirm.

    NOTE(review): for reduced-precision training, prefer a mixed precision
    policy (e.g. `set_global_policy('mixed_float16')`) over setting the
    global float type to 'float16'.
    """
def epsilon():
    """
    Get the numerical epsilon value.

    The epsilon is a small constant typically added in numerically
    sensitive operations (e.g. to keep denominators away from zero).

    Returns:
        float: Small constant for numerical stability. NOTE(review):
            Keras's default is 1e-7 — confirm.
    """
def set_epsilon(value):
    """
    Set the numerical epsilon value.

    Args:
        value (float): Small constant for numerical stability; reported
            by `epsilon()` afterwards.
    """
Settings for controlling data layout and format conventions.
def image_data_format():
    """
    Get the default image data format.

    Returns:
        str: Data format ('channels_last' or 'channels_first').
            'channels_last' puts the channel axis last, e.g.
            (batch, height, width, channels) for 2D images;
            'channels_first' puts it directly after the batch axis.
            NOTE(review): Keras's default is 'channels_last' — confirm.
    """
def set_image_data_format(data_format):
    """
    Set the default image data format.

    Args:
        data_format (str): Format to use ('channels_last' or
            'channels_first'). NOTE(review): other values are presumably
            rejected with ValueError — confirm.
    """
Functions for managing backend sessions and clearing state.
def clear_session():
    """
    Clear backend session and free memory.

    This function clears any cached state, resets default graph,
    and triggers garbage collection to free up memory.

    NOTE(review): "resets default graph" only applies to the TensorFlow
    backend. Because the cleared state is global, confirm whether it also
    covers settings applied earlier in the process (e.g. the mixed
    precision policy or a stored random seed) before calling this after
    configuration. Clearing first and configuring afterwards avoids the
    question entirely.
    """
def get_uid(prefix=''):
    """
    Generate unique identifier for naming.

    Args:
        prefix (str): Prefix for the identifier

    Returns:
        Unique ID for `prefix`. NOTE(review): Keras documents `get_uid`
        as returning an *integer* counter kept per prefix (1, 2, ... on
        successive calls with the same prefix), not a string — confirm
        against the installed version, and combine it with the prefix
        yourself (e.g. f"{prefix}_{uid}") when a name string is needed.
    """
Functions for controlling random number generation across backends.
def set_random_seed(seed):
    """
    Set global random seed for reproducibility.

    This sets the random seed for the current backend, NumPy,
    and Python's random module to ensure reproducible results.

    Args:
        seed (int): Random seed value

    Usage: called as `keras.utils.set_random_seed(seed)`.
    NOTE(review): if you also call `clear_session()`, call it *before*
    seeding so that clearing Keras global state cannot undo the seeding —
    confirm for your Keras version.
    """
Utilities for working with data types across different backends.
def is_keras_tensor(x):
    """
    Check if object is a Keras tensor.

    Args:
        x: Object to check

    Returns:
        bool: True if x is a Keras tensor. NOTE(review): "Keras tensor"
            presumably means a symbolic tensor (e.g. one produced by
            `keras.Input`) rather than a concrete backend array — confirm.
    """
def is_float_dtype(dtype):
    """
    Check if data type is floating point.

    Args:
        dtype (str or dtype): Data type to check, given either as a name
            (e.g. 'float32') or as a dtype object.

    Returns:
        bool: True if dtype is floating point. NOTE(review): whether
            'bfloat16' counts as floating point is not stated here —
            confirm.
    """
def is_int_dtype(dtype):
    """
    Check if data type is integer.

    Args:
        dtype (str or dtype): Data type to check, given either as a name
            (e.g. 'int32') or as a dtype object.

    Returns:
        bool: True if dtype is integer. NOTE(review): whether unsigned
            types ('uint8', ...) and 'bool' count as integer is not
            stated here — confirm.
    """
def standardize_dtype(dtype):
    """
    Standardize data type string representation.

    Maps the different ways a dtype can be spelled (a name string or a
    dtype object) onto one canonical string so dtypes can be compared.

    Args:
        dtype (str or dtype): Data type to standardize

    Returns:
        str: Standardized dtype string, e.g. 'float32'. NOTE(review):
            confirm the exact canonical names and how Python builtins such
            as `float`/`int` are mapped.
    """
def result_type(*dtypes):
    """
    Determine result data type from multiple input types.

    Args:
        *dtypes: Input data types (any number of names or dtype objects).

    Returns:
        str: Result data type, i.e. the dtype an operation combining
            inputs of these dtypes would produce. NOTE(review): the exact
            promotion rules and the behavior with zero arguments are not
            specified here — confirm.
    """
Functions for device placement and context management.
def device(device_name):
    """
    Device placement context manager.

    Use it in a `with` statement: work created inside the block is placed
    on `device_name`.

    Args:
        device_name (str): Device name ('cpu', 'gpu', 'gpu:0', etc.)

    Returns:
        context manager: Device placement context

    NOTE(review): in Keras 3 the public entry point appears to be
    `keras.device(...)`; confirm it is also reachable as
    `keras.backend.device` before relying on that path.
    """
def name_scope(name):
    """
    Name scoping context manager for operations.

    Use it in a `with` statement to group the operations created inside
    the block under the scope `name`.

    Args:
        name (str): Scope name

    Returns:
        context manager: Name scope context

    NOTE(review): Keras 3 exposes this as `keras.name_scope` — confirm the
    import path for your version.
    """
Settings for mixed precision training and inference.
# Available in keras.mixed_precision
def set_global_policy(policy):
    """
    Set global mixed precision policy.

    Layers built after this call use the policy automatically unless they
    are given an explicit `dtype` (e.g. `dtype='float32'` on an output
    layer).

    Args:
        policy (str or Policy): Policy name or Policy instance
            Common policies: 'mixed_float16', 'mixed_bfloat16', 'float32'

    When training with 'mixed_float16', wrap the optimizer in
    `keras.optimizers.LossScaleOptimizer` for stable training.
    """
def global_policy():
    """
    Get current global mixed precision policy.

    Returns:
        Policy: Current mixed precision policy, i.e. the one most recently
            set with `set_global_policy`. NOTE(review): confirm the default
            returned when it was never called (presumably derived from
            `floatx()`).
    """
import keras
from keras import backend

# Report which backend Keras is running on.
active_backend = backend.backend()
print(f"Current backend: {active_backend}")

# Default floating-point dtype: set it, then read it back.
backend.set_floatx('float32')
current_floatx = backend.floatx()
print(f"Default float type: {current_floatx}")

# Image layout convention: set it, then read it back.
backend.set_image_data_format('channels_last')
current_format = backend.image_data_format()
print(f"Image data format: {current_format}")

# Seed Python, NumPy and the backend for reproducible results.
keras.utils.set_random_seed(42)

# Drop Keras-managed global state to free memory.
backend.clear_session()

import keras
from keras import backend

# Device placement uses the top-level `keras.device(...)` context manager,
# the documented Keras 3 device-scope API (it is not part of the
# documented `keras.backend` namespace).

# Run these ops on the CPU.
with keras.device('cpu'):
    x = keras.ops.ones((1000, 1000))
    y = keras.ops.matmul(x, x)

# Place the model on the first GPU. This assumes a GPU is present; on a
# CPU-only machine use 'cpu' instead.
with keras.device('gpu:0'):
    model = keras.Sequential([
        keras.layers.Dense(64, activation='relu'),
        keras.layers.Dense(10, activation='softmax'),
    ])
    predictions = model(x)

import keras
from keras import mixed_precision

# Enable mixed precision: compute in float16 while keeping variables in
# float32.
mixed_precision.set_global_policy('mixed_float16')

# Layers built after the policy is set pick it up automatically.
# Declare the input with keras.Input instead of passing `input_shape` to a
# layer; Keras 3 discourages the latter and emits a warning for it.
model = keras.Sequential([
    keras.Input(shape=(784,)),
    keras.layers.Dense(64, activation='relu'),
    # Keep the softmax output in float32 for numerically stable results.
    keras.layers.Dense(10, activation='softmax', dtype='float32'),
])

# Wrap the optimizer in LossScaleOptimizer so small float16 gradients do
# not underflow to zero.
optimizer = keras.optimizers.Adam()
optimizer = keras.optimizers.LossScaleOptimizer(optimizer)
model.compile(
    optimizer=optimizer,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Train normally - mixed precision is handled automatically.
# x_train / y_train / x_val / y_val are placeholders for your own data
# (integer class labels, given the sparse loss).
model.fit(x_train, y_train, epochs=10, validation_data=(x_val, y_val))

import keras
from keras import backend

# Backend-specific tuning: query the backend once and branch on it.
backend_name = backend.backend()
if backend_name == 'tensorflow':
    import tensorflow as tf
    # Allocate GPU memory on demand instead of reserving it all up front.
    # Memory growth is a per-device setting; it must be configured before
    # the GPUs are initialized, otherwise TensorFlow raises RuntimeError.
    # (Assigning `tf.config.experimental.enable_memory_growth = True`
    # only creates an unused attribute and has no effect.)
    for gpu in tf.config.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(gpu, True)
elif backend_name == 'jax':
    import jax
    # Allow 64-bit arrays; set this before any JAX arrays are created.
    jax.config.update('jax_enable_x64', True)
elif backend_name == 'torch':
    import torch
    # Let cuDNN benchmark and cache the fastest convolution algorithms
    # (best when input shapes stay fixed). Note: benchmark mode can make
    # results non-reproducible even with a fixed seed.
    torch.backends.cudnn.benchmark = True

# Universal settings
backend.set_floatx('float32')
backend.set_image_data_format('channels_last')
keras.utils.set_random_seed(42)

import keras
from keras import backend
import gc


def train_with_memory_management(model, train_data, val_data):
    """Train ``model``, then release Keras global state and collect garbage.

    Args:
        model: A compiled Keras model (``fit`` requires ``compile``).
        train_data: Training data passed to ``model.fit``.
        val_data: Validation data passed as ``validation_data``.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    history = model.fit(
        train_data,
        validation_data=val_data,
        epochs=10,
    )
    # Release Keras global state, then force a garbage-collection pass.
    backend.clear_session()
    gc.collect()
    return history


# Usage
# Clear leftover global state *before* building the next model. Clearing
# inside the training function would happen after `model` was already
# built under that state.
backend.clear_session()
model = keras.Sequential([...])  # [...] is a placeholder for your layers
# fit() raises on an uncompiled model, so compile before training.
model.compile(optimizer='adam', loss='mse')
# train_dataset / val_dataset are placeholders for your own data.
history = train_with_memory_management(model, train_dataset, val_dataset)

import keras
from keras import backend
import numpy as np
import random
import os
def setup_reproducible_training(seed=42):
"""Set up reproducible training environment."""
# Set random seeds
keras.utils.set_random_seed(seed)
np.random.seed(seed)
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
# Backend-specific reproducibility
if backend.backend() == 'tensorflow':
import tensorflow as tf
tf.config.experimental.enable_op_determinism()
# Clear any existing state
backend.clear_session()
print(f"Reproducible training setup complete with seed {seed}")
# Setup reproducible environment
setup_reproducible_training(42)
# Now build and train model
model = keras.Sequential([...])
model.compile(optimizer='adam', loss='mse')
model.fit(x_train, y_train, epochs=10)Install with Tessl CLI
npx tessl i tessl/pypi-keras-nightly