CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-jax

Differentiate, compile, and transform Numpy code.

Pending
Overview
Eval results
Files

device-memory.mddocs/

Device and Memory Management

JAX provides comprehensive device management and distributed computing capabilities, enabling efficient use of CPUs, GPUs, and TPUs. This includes device placement, memory management, sharding for multi-device computation, and distributed array operations.

Core Imports

import jax
import jax.numpy as jnp
from jax import devices, device_put, make_mesh
from jax.sharding import NamedSharding, PartitionSpec as P

Capabilities

Device Discovery and Information

Query available devices and their properties for computation placement and resource management.

def devices(backend=None) -> list[Device]:
    """
    Get list of all available devices.
    
    Args:
        backend: Optional backend name ('cpu', 'gpu', 'tpu')
        
    Returns:
        List of available Device objects
    """

def local_devices(process_index=None, backend=None) -> list[Device]:
    """
    Get list of devices local to current process.
    
    Args:
        process_index: Process index (None for current process)
        backend: Optional backend name
        
    Returns:
        List of local Device objects
    """

def device_count(backend=None) -> int:
    """
    Get total number of devices across all processes.
    
    Args:
        backend: Optional backend name
        
    Returns:
        Total device count
    """

def local_device_count(backend=None) -> int:
    """
    Get number of devices on current process.
    
    Args:
        backend: Optional backend name
        
    Returns:
        Local device count
    """

def host_count(backend=None) -> int:
    """
    Get number of hosts in distributed computation (deprecated alias of process_count).
    
    Args:
        backend: Optional backend name
        
    Returns:
        Host count
    """

def host_id(backend=None) -> int:
    """
    Get ID of current host (deprecated alias of process_index).
    
    Args:
        backend: Optional backend name
        
    Returns:
        Current host ID
    """

def host_ids(backend=None) -> list[int]:
    """
    Get list of all host IDs (deprecated alias of process_indices).
    
    Args:
        backend: Optional backend name
        
    Returns:
        List of host IDs
    """

def process_count(backend=None) -> int:
    """
    Get number of processes in distributed computation.
    
    Args:
        backend: Optional backend name
        
    Returns:
        Process count
    """

def process_index(backend=None) -> int:
    """
    Get index of current process.
    
    Args:
        backend: Optional backend name
        
    Returns:
        Current process index
    """

def process_indices(backend=None) -> list[int]:
    """
    Get list of all process indices.
    
    Args:
        backend: Optional backend name
        
    Returns:
        List of process indices
    """

def default_backend() -> str:
    """
    Get name of default backend.
    
    Returns:
        Default backend name string
    """

Device Placement and Data Movement

Control where computations run and move data between devices and host memory.

def device_put(x, device=None, src=None) -> Array:
    """
    Move array to specified device.
    
    Args:
        x: Array or array-like object to move
        device: Target device (None for default device)
        src: Source device for the transfer
        
    Returns:
        Array placed on target device
    """

def device_put_sharded(
    sharded_values: list, 
    devices: list[Device],
    indices=None
) -> Array:
    """
    Create sharded array from per-device values.
    
    Args:
        sharded_values: List of arrays, one per device
        devices: List of target devices
        indices: Optional sharding indices
        
    Returns:
        Distributed array sharded across devices
    """

def device_put_replicated(x, devices: list[Device]) -> Array:
    """
    Replicate array across multiple devices.
    
    Args:
        x: Array to replicate
        devices: List of target devices
        
    Returns:
        Array replicated across all specified devices
    """

def device_get(x) -> Any:
    """
    Move array from device to host memory as NumPy array.
    
    Args:
        x: Array to move to host
        
    Returns:
        NumPy array in host memory
    """

def copy_to_host_async(x) -> Any:
    """
    Asynchronously copy array to host memory.
    
    Args:
        x: Array to copy
        
    Returns:
        Future-like object for async copy
    """

def block_until_ready(x) -> Array:
    """
    Block until array computation is complete and ready.
    
    Args:
        x: Array to wait for
        
    Returns:
        The same array, guaranteed to be ready
    """

Usage examples:

# Check available devices
all_devices = jax.devices()
print(f"Available devices: {all_devices}")
print(f"Device count: {jax.device_count()}")

# Move data to specific device (guarded so the example also works on CPU-only machines)
cpu_data = jnp.array([1, 2, 3, 4])
if jax.devices('gpu'):
    gpu_data = jax.device_put(cpu_data, jax.devices('gpu')[0])
    print(f"Data is on: {gpu_data.device()}")

    # Move back to host
    host_data = jax.device_get(gpu_data)  # Returns NumPy array

# Explicit device placement in computations
with jax.default_device(jax.devices('cpu')[0]):
    cpu_result = jnp.sum(jnp.array([1, 2, 3]))

Sharding and Distributed Arrays

Define how arrays are distributed across multiple devices for parallel computation.

class NamedSharding:
    """
    Sharding specification using named mesh axes.
    
    Defines how arrays are partitioned across devices using logical axis names.
    """
    
    def __init__(self, mesh, spec):
        """
        Create named sharding specification.
        
        Args:
            mesh: Device mesh with named axes
            spec: Partition specification (PartitionSpec)
        """
        self.mesh = mesh
        self.spec = spec

class PartitionSpec:
    """
    Specification for how to partition array dimensions across mesh axes.
    
    Use P(axis_names...) to create partition specifications.
    """
    pass

# Alias for PartitionSpec  
P = PartitionSpec

def make_mesh(mesh_shape, axis_names) -> Mesh:
    """
    Create device mesh for distributed computation.
    
    Args:
        mesh_shape: Shape of device mesh (tuple of integers)
        axis_names: Names for mesh axes (tuple of strings)
        
    Returns:
        Mesh object representing device layout
    """

class Mesh:
    """Device mesh for distributed computation."""
    devices: Array  # Device array in mesh shape
    axis_names: tuple[str, ...]  # Names of mesh axes
    
    @property
    def shape(self) -> dict[str, int]:
        """Dictionary mapping axis names to sizes."""
        
    @property 
    def size(self) -> int:
        """Total number of devices in mesh."""

def make_array_from_single_device_arrays(
    arrays: list[Array],
    sharding: Sharding
) -> Array:
    """
    Create distributed array from per-device arrays.
    
    Args:
        arrays: List of arrays on different devices
        sharding: Sharding specification
        
    Returns:
        Distributed array with specified sharding
    """

def make_array_from_callback(
    shape: tuple[int, ...],
    sharding: Sharding, 
    data_callback: Callable
) -> Array:
    """
    Create distributed array using callback function.
    
    Args:
        shape: Global array shape
        sharding: Sharding specification  
        data_callback: Function to generate data for each shard
        
    Returns:
        Distributed array created from callback
    """

def make_array_from_process_local_data(
    sharding: Sharding,
    local_data: Array
) -> Array:
    """
    Create distributed array from process-local data.
    
    Args:
        sharding: Sharding specification
        local_data: Data local to current process
        
    Returns:
        Distributed array assembled from local data
    """

Sharded Computation

Execute computations on sharded arrays with explicit control over parallelization.

def shard_map(
    f: Callable,
    mesh: Mesh,
    in_specs,
    out_specs,
    check_rep=True
) -> Callable:
    """
    Transform function to operate on sharded arrays.
    
    Args:
        f: Function to transform
        mesh: Device mesh for computation
        in_specs: Input sharding specifications
        out_specs: Output sharding specifications  
        check_rep: Whether to check for replication consistency
        
    Returns:
        Function that operates on globally sharded arrays
    """

# Alias for shard_map
smap = shard_map

def with_sharding_constraint(x, sharding) -> Array:
    """
    Add sharding constraint to array.
    
    Args:
        x: Input array
        sharding: Desired sharding specification
        
    Returns:
        Array with sharding constraint applied
    """

Usage examples:

# Create 2x2 device mesh (requires at least 4 devices;
# make_mesh arranges the available devices into the requested shape itself)
mesh = jax.make_mesh((2, 2), ('data', 'model'))

# Define sharding specifications
data_sharding = NamedSharding(mesh, P('data', None))  # Shard first axis across 'data'
model_sharding = NamedSharding(mesh, P(None, 'model'))  # Shard second axis across 'model'
replicated_sharding = NamedSharding(mesh, P())  # Replicated across all devices

# Create sharded arrays
x = jax.random.normal(jax.random.key(0), (8, 4))
x_sharded = jax.device_put(x, data_sharding)

weights = jax.random.normal(jax.random.key(1), (4, 8))
weights_sharded = jax.device_put(weights, model_sharding)

# Computation with sharded arrays automatically parallelized  
@jax.jit
def matmul_fn(x, w):
    return x @ w

result = matmul_fn(x_sharded, weights_sharded)  # Automatically sharded computation

# Explicit sharding control
def single_device_fn(x_shard, w_shard):
    return x_shard @ w_shard

parallel_fn = jax.shard_map(
    single_device_fn,
    mesh=mesh,
    in_specs=(P('data', None), P(None, 'model')),
    out_specs=P('data', 'model')
)

result = parallel_fn(x_sharded, weights_sharded)

Memory Management

Control memory usage and optimize performance through explicit memory management.

def live_arrays() -> list[Array]:
    """
    Get list of arrays currently alive in memory.
    
    Returns:
        List of live Array objects
    """

def clear_caches() -> None:
    """
    Clear JAX's internal caches to free memory.
    
    Clears the JIT compilation cache and other internal staging caches.
    Note: this does not free live device arrays; subsequent jitted calls
    will recompile.
    """

Configuration and Backend Management

Configure device behavior and backend selection.

# Configuration through jax.config
jax.config.update('jax_platform_name', 'cpu')  # Force CPU backend
jax.config.update('jax_platform_name', 'gpu')  # Force GPU backend  
jax.config.update('jax_platform_name', 'tpu')  # Force TPU backend

# Transfer guards to catch unintentional device transfers
jax.config.update('jax_transfer_guard', 'allow')    # Default: allow all transfers
jax.config.update('jax_transfer_guard', 'log')      # Log transfers  
jax.config.update('jax_transfer_guard', 'disallow') # Disallow transfers
jax.config.update('jax_transfer_guard', 'log_explicit_device_put') # Log explicit transfers

# Default device configuration
jax.config.update('jax_default_device', jax.devices('gpu')[0])  # Set default device

Array and Device Properties

Inspect array placement and device properties.

# Array device methods
array.device() -> Device  # Get device containing array
array.devices() -> set[Device]  # Get all devices for distributed array
array.sharding -> Sharding  # Get array's sharding specification
array.is_fully_replicated -> bool  # Check if array is replicated
array.is_fully_addressable -> bool  # Check if array is fully addressable

# Device properties
class Device:
    """Device object representing compute accelerator."""
    
    platform: str  # Platform name ('cpu', 'gpu', 'tpu')
    device_kind: str  # Device kind string  
    id: int  # Device ID within platform
    host_id: int  # Host ID containing device
    process_index: int  # Process index containing device
    
    def __str__(self) -> str: ...
    def __repr__(self) -> str: ...

Advanced Usage Patterns

Multi-Device Training

# Setup for data-parallel training
def create_train_setup(num_devices):
    # Create mesh for data parallelism
    mesh = jax.make_mesh((num_devices,), ('batch',))
    
    # Sharding specifications
    batch_sharding = NamedSharding(mesh, P('batch'))  # Batch dimension sharded
    replicated_sharding = NamedSharding(mesh, P())    # Parameters replicated
    
    return mesh, batch_sharding, replicated_sharding

def distributed_train_step(params, batch, optimizer_state):
    # All arrays should already have appropriate sharding
    grads = jax.grad(loss_fn)(params, batch)
    
    # Update step automatically uses sharding from inputs
    new_params, new_state = optimizer.update(grads, optimizer_state, params)
    return new_params, new_state

# JIT compile with sharding
distributed_train_step = jax.jit(
    distributed_train_step,
    in_shardings=(replicated_sharding, batch_sharding, replicated_sharding),
    out_shardings=(replicated_sharding, replicated_sharding)
)

Model Parallelism

# Setup for model-parallel computation
def create_model_parallel_setup():
    # 2D mesh: batch x model dimensions
    mesh = jax.make_mesh((2, 4), ('batch', 'model'))
    
    # Different sharding strategies
    input_sharding = NamedSharding(mesh, P('batch', None))
    weight_sharding = NamedSharding(mesh, P(None, 'model'))  
    output_sharding = NamedSharding(mesh, P('batch', 'model'))
    
    return mesh, input_sharding, weight_sharding, output_sharding

def model_parallel_layer(x, weights):
    # Matrix multiply with different sharding patterns
    return x @ weights  # JAX handles the communication automatically

# Shard arrays according to strategy
x = jax.device_put(x, input_sharding)
weights = jax.device_put(weights, weight_sharding)
result = model_parallel_layer(x, weights)  # Result has output_sharding

Memory-Efficient Inference

def memory_efficient_inference(model_fn, large_input):
    # Process in chunks to manage memory
    chunk_size = 1000
    chunks = [large_input[i:i+chunk_size] for i in range(0, len(large_input), chunk_size)]
    
    results = []
    for chunk in chunks:
        # Move to device, compute, move back to host
        device_chunk = jax.device_put(chunk)
        device_result = model_fn(device_chunk)
        host_result = jax.device_get(device_result)
        results.append(host_result)
        
        # Optional: clear caches to free memory
        # (tradeoff: this also drops the JIT cache, forcing recompilation
        # of model_fn on the next chunk)
        jax.clear_caches()
    
    return jnp.concatenate(results)

Cross-Device Communication Patterns

# Collective operations using pmap
# Note: collectives like psum/all_gather reference a named axis, so the
# axis name must be bound via pmap's axis_name argument.
import functools

@functools.partial(jax.pmap, axis_name='batch')
def allreduce_example(x):
    # Sum across all devices
    return jax.lax.psum(x, axis_name='batch')

@functools.partial(jax.pmap, axis_name='batch')
def allgather_example(x):
    # Gather from all devices
    return jax.lax.all_gather(x, axis_name='batch')

# Use with replicated data
replicated_data = jax.device_put_replicated(data, jax.devices())
summed_result = allreduce_example(replicated_data)
gathered_result = allgather_example(replicated_data)

Install with Tessl CLI

npx tessl i tessl/pypi-jax

docs

core-transformations.md

device-memory.md

experimental.md

index.md

low-level-ops.md

neural-networks.md

numpy-compatibility.md

random-numbers.md

scipy-compatibility.md

tree-operations.md

tile.json