tessl/pypi-keras-nightly

Multi-backend deep learning framework providing a unified API for building and training neural networks across JAX, TensorFlow, PyTorch, and OpenVINO backends

Backend Configuration

Backend configuration utilities for managing numerical precision, image data formats, random seeds, device placement, and mixed precision consistently across the JAX, TensorFlow, PyTorch, and OpenVINO backends.

Capabilities

Backend Information

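Functions to query the active backend and the devices it can use. backend() is available as keras.backend.backend(); list_devices() is exported as keras.distribution.list_devices().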

def backend():
    """
    Get the name of the current backend.
    
    Returns:
        str: Backend name ('jax', 'tensorflow', 'torch', or 'openvino')
    """

def list_devices(device_type=None):
    """
    List available compute devices.
    
    Args:
        device_type (str, optional): Filter by device type ('cpu', 'gpu', 'tpu')
        
    Returns:
        list: Available devices
    """
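The sketch below is a quick environment check. It assumes Keras 3, where list_devices is exported as keras.distribution.list_devices. Device listing relies on each backend's distribution support, so the call sits inside a try block instead of being assumed to work everywhere.

```python
import keras

# Name of the active backend, e.g. "jax", "tensorflow", "torch" or "openvino"
backend_name = keras.backend.backend()
print(f"Active backend: {backend_name}")

# Devices visible to the backend; device_type may be "cpu", "gpu" or "tpu".
# Guarded because not every backend implements device listing.
try:
    cpu_devices = keras.distribution.list_devices("cpu")
except (AttributeError, NotImplementedError, RuntimeError) as err:
    cpu_devices = []
    print(f"Device listing unavailable on this backend: {err}")
print(f"CPU devices: {cpu_devices}")
```

The backend is selected before Keras is imported, through the KERAS_BACKEND environment variable or ~/.keras/keras.json. backend() only reports that choice.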

Numerical Precision Configuration

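Settings for the default floating-point type and for epsilon, the small fuzz factor used to avoid division by zero and similar numerical problems. These functions are exported from both keras.backend and keras.config.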

def floatx():
    """
    Get the default floating-point type.
    
    Returns:
        str: Default float type ('bfloat16', 'float16', 'float32', or 'float64'; 'float32' by default)
    """

def set_floatx(dtype):
    """
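    Set the default floating-point type.
    
    Setting this to 'float16' for training is not recommended because it
    tends to cause numerical instability; use a mixed precision policy instead.
    
    Args:
        dtype (str): Float type to use ('bfloat16', 'float16', 'float32', or 'float64')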
    """

def epsilon():
    """
    Get the fuzz factor used in numeric expressions (1e-7 by default).
    
    Returns:
        float: Small constant added to denominators and similar terms for numerical stability
    """

def set_epsilon(value):
    """
    Set the fuzz factor used in numeric expressions.
    
    Args:
        value (float): Small constant for numerical stability
    """

Data Format Configuration

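Settings that control how image tensors are laid out. Convolution, pooling, and other spatial layers fall back to this default whenever their data_format argument is left as None.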

def image_data_format():
    """
    Get the default image data format.
    
    Returns:
        str: Data format ('channels_last' or 'channels_first'; 'channels_last' by default)
    """

def set_image_data_format(data_format):
    """
    Set the default image data format.
    
    Args:
        data_format (str): 'channels_last' for (batch, height, width, channels) or 'channels_first' for (batch, channels, height, width)
    """
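The effect of the setting is easiest to see in a layer's output shape. This sketch builds the same Conv2D under each format and asks for its output shape without running a convolution. It relies on spatial layers reading the global default at construction time, so restoring the old value afterwards does not change the layers already created.

```python
import keras
from keras import backend

old_format = backend.image_data_format()

# channels_last: inputs are (batch, height, width, channels)
backend.set_image_data_format("channels_last")
conv_last = keras.layers.Conv2D(8, 3, padding="same")
shape_last = conv_last.compute_output_shape((1, 32, 32, 3))

# channels_first: inputs are (batch, channels, height, width)
backend.set_image_data_format("channels_first")
conv_first = keras.layers.Conv2D(8, 3, padding="same")
shape_first = conv_first.compute_output_shape((1, 3, 32, 32))

# Each layer captured the format when it was created, so restoring is safe
backend.set_image_data_format(old_format)
print(shape_last, shape_first)
```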

Session and State Management

Functions for resetting Keras's global state and generating unique names.

def clear_session():
    """
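    Reset all global state held by Keras and free memory.
    
    Keras keeps global state to implement the Functional API and to give
    auto-generated layer names unique suffixes. Clearing it is useful when
    building many models in a loop. On TensorFlow it also resets the default
    graph, and it triggers garbage collection to free up memory.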
    """

def get_uid(prefix=''):
    """
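    Return a unique integer ID for the given prefix.
    
    Each prefix has its own counter, which starts at 1 and is reset by
    clear_session().
    
    Args:
        prefix (str): Prefix for the identifier
        
    Returns:
        int: Unique integer identifier for this prefix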
    """
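The sketch below shows the per-prefix counters behind get_uid and how clear_session() resets them. The prefixes "block" and "head" are arbitrary illustrative names. The two Dense layers show that auto-generated layer names are also kept unique through Keras's global state.

```python
import keras
from keras import backend

# Start from a clean global state
backend.clear_session()

# Counters are kept per prefix and start at 1
first = backend.get_uid("block")
second = backend.get_uid("block")
other = backend.get_uid("head")

# Auto-generated layer names are uniquified through the global state too
dense_a = keras.layers.Dense(4)
dense_b = keras.layers.Dense(4)
print(dense_a.name, dense_b.name)

# clear_session() resets every counter
backend.clear_session()
after_reset = backend.get_uid("block")
```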

Random Seed Configuration

Functions for controlling random number generation across backends. set_random_seed is available as keras.utils.set_random_seed.

def set_random_seed(seed):
    """
    Set global random seed for reproducibility.
    
    This sets the random seed for the current backend, NumPy,
    and Python's random module to ensure reproducible results.
    
    Args:
        seed (int): Random seed value
    """
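A minimal check of the reproducibility guarantee for the non-backend generators: reseeding through keras.utils.set_random_seed makes Python's random and NumPy's global generator repeat the same draws. The helper draw() is a hypothetical name used only for this illustration.

```python
import random

import numpy as np
import keras


def draw(seed):
    """Seed everything through Keras, then sample Python's and NumPy's global RNGs."""
    keras.utils.set_random_seed(seed)
    return random.random(), np.random.rand(3).tolist()


run_a = draw(42)
run_b = draw(42)  # same seed, so the same draws
run_c = draw(7)   # different seed, so different draws
print(run_a == run_b, run_a == run_c)
```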

Data Type Utilities

Utilities in keras.backend for inspecting and normalizing data types in a backend-independent way.

def is_keras_tensor(x):
    """
    Check whether x is a symbolic KerasTensor (for example the output of keras.Input) rather than a concrete backend tensor.
    
    Args:
        x: Object to check
        
    Returns:
        bool: True if x is a Keras tensor
    """

def is_float_dtype(dtype):
    """
    Check if data type is floating point.
    
    Args:
        dtype (str or dtype): Data type to check
        
    Returns:
        bool: True if dtype is floating point
    """

def is_int_dtype(dtype):
    """
    Check if data type is integer.
    
    Args:
        dtype (str or dtype): Data type to check
        
    Returns:
        bool: True if dtype is integer
    """

def standardize_dtype(dtype):
    """
    Convert a dtype given as a string, Python type, or NumPy/backend dtype to its canonical Keras name (e.g. 'float32').
    
    Args:
        dtype (str or dtype): Data type to standardize
        
    Returns:
        str: Standardized dtype string
    """

def result_type(*dtypes):
    """
    Determine the dtype that results from combining the given dtypes under Keras's type-promotion rules.
    
    Args:
        *dtypes: Input data types
        
    Returns:
        str: Result data type
    """

Device Management

Context managers for device placement and operation naming, exported at the top level as keras.device and keras.name_scope.

def device(device_name):
    """
    Device placement context manager.
    
    Args:
        device_name (str): Device name ('cpu', 'gpu', 'gpu:0', etc.)
        
    Returns:
        context manager: Device placement context
    """

def name_scope(name):
    """
    Name scoping context manager for operations.
    
    Args:
        name (str): Scope name
        
    Returns:
        context manager: Name scope context
    """

Mixed Precision Configuration

Settings for mixed precision training and inference.

# Available in keras.mixed_precision
def set_global_policy(policy):
    """
    Set global mixed precision policy.
    
    Args:
        policy (str or Policy): Policy name or Policy instance
            Common policies: 'mixed_float16', 'mixed_bfloat16', 'float32'
    """

def global_policy():
    """
    Get current global mixed precision policy.
    
    Returns:
        Policy: Current mixed precision policy
    """
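A policy splits a layer's dtype into a compute dtype and a variable dtype. The sketch below sets the 'mixed_bfloat16' policy, checks what a newly built Dense layer inherits, and then restores the previous policy. It uses bfloat16 rather than float16 only to keep the example independent of loss scaling.

```python
import keras
from keras import mixed_precision

previous = mixed_precision.global_policy()

# bfloat16 computations with float32 variables
mixed_precision.set_global_policy("mixed_bfloat16")
policy = mixed_precision.global_policy()

# Layers created from now on pick up the global policy
layer = keras.layers.Dense(4)
layer.build((None, 8))
print(policy.name, layer.compute_dtype, layer.variable_dtype)

# Restore the previous policy so later code is unaffected
mixed_precision.set_global_policy(previous)
```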

Usage Examples

Basic Backend Configuration

import keras
from keras import backend

# Check current backend
print(f"Current backend: {backend.backend()}")

# Configure floating point precision
backend.set_floatx('float32')
print(f"Default float type: {backend.floatx()}")

# Set image data format
backend.set_image_data_format('channels_last')
print(f"Image data format: {backend.image_data_format()}")

# Set random seed for reproducibility
keras.utils.set_random_seed(42)

# Clear session to free memory
backend.clear_session()

Device Placement

import keras

# Use CPU for specific operations
with keras.device('cpu'):
    x = keras.ops.ones((1000, 1000))
    y = keras.ops.matmul(x, x)

# Place the model on the first GPU. This needs a visible GPU; without one,
# the block either errors or falls back to the CPU, depending on the backend
with keras.device('gpu:0'):
    model = keras.Sequential([
        keras.layers.Dense(64, activation='relu'),
        keras.layers.Dense(10, activation='softmax')
    ])
    
    predictions = model(x)

Mixed Precision Training

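import numpy as np
import keras
from keras import mixed_precision

# Enable mixed precision: float16 computations, float32 variables
mixed_precision.set_global_policy('mixed_float16')

# Build model (hidden layers pick up the global policy automatically)
model = keras.Sequential([
    keras.Input(shape=(784,)),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(10, activation='softmax', dtype='float32')  # Keep output in float32
])

# Use LossScaleOptimizer to avoid float16 gradient underflow
optimizer = keras.optimizers.Adam()
optimizer = keras.optimizers.LossScaleOptimizer(optimizer)

model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Placeholder data for illustration; substitute your own dataset
x_train = np.random.rand(1024, 784).astype('float32')
y_train = np.random.randint(0, 10, size=(1024,))
x_val = np.random.rand(256, 784).astype('float32')
y_val = np.random.randint(0, 10, size=(256,))

# Train normally - mixed precision is handled automatically
model.fit(x_train, y_train, epochs=10, validation_data=(x_val, y_val))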

Backend-Specific Configuration

import keras
from keras import backend

# Configuration based on backend
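if backend.backend() == 'tensorflow':
    # TensorFlow-specific settings: allocate GPU memory on demand
    # (must run before any GPU has been initialized)
    import tensorflow as tf
    for gpu in tf.config.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(gpu, True)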
    
elif backend.backend() == 'jax':
    # JAX-specific settings
    import jax
    jax.config.update('jax_enable_x64', True)
    
elif backend.backend() == 'torch':
    # PyTorch-specific settings
    import torch
    torch.backends.cudnn.benchmark = True

# Universal settings
backend.set_floatx('float32')
backend.set_image_data_format('channels_last')
keras.utils.set_random_seed(42)

Memory Management

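import gc
import keras
from keras import backend

def train_with_memory_management(build_model, train_data, val_data, epochs=10):
    """Build a fresh model, train it, and release Keras state before and after."""
    
    # Reset Keras global state (e.g. auto-generated layer names) before building
    backend.clear_session()
    model = build_model()
    
    # Train model
    history = model.fit(
        train_data,
        validation_data=val_data,
        epochs=epochs
    )
    
    # Drop the model, clear session, and force garbage collection
    del model
    backend.clear_session()
    gc.collect()
    
    # Return the plain metrics dict so no reference to the model is kept
    return history.history

# Usage: build_model returns a compiled model; train_dataset and val_dataset
# stand in for your own training and validation data
def build_model():
    model = keras.Sequential([
        keras.Input(shape=(784,)),
        keras.layers.Dense(64, activation='relu'),
        keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
    return model

history = train_with_memory_management(build_model, train_dataset, val_dataset)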

Reproducible Training Setup

import keras
from keras import backend

def setup_reproducible_training(seed=42):
    """Set up reproducible training environment."""
    
    # Clear any existing Keras state before seeding rather than after
    backend.clear_session()
    
    # Seeds Python's random module, NumPy, and the backend framework in one call
    keras.utils.set_random_seed(seed)
    
    # PYTHONHASHSEED only takes effect if set before the interpreter starts,
    # e.g. `PYTHONHASHSEED=42 python train.py`; setting it here has no effect
    
    # Backend-specific reproducibility
    if backend.backend() == 'tensorflow':
        import tensorflow as tf
        tf.config.experimental.enable_op_determinism()
    
    print(f"Reproducible training setup complete with seed {seed}")

# Setup reproducible environment
setup_reproducible_training(42)

# Now build and train model (x_train and y_train stand in for your own data)
model = keras.Sequential([
    keras.Input(shape=(784,)),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(1)
])
model.compile(optimizer='adam', loss='mse')
model.fit(x_train, y_train, epochs=10)

Install with Tessl CLI

npx tessl i tessl/pypi-keras-nightly
