Differentiate, compile, and transform Numpy code.
—
JAX provides comprehensive device management and distributed computing capabilities, enabling efficient use of CPUs, GPUs, and TPUs. This includes device placement, memory management, sharding for multi-device computation, and distributed array operations.
import functools

import jax
from jax import devices, device_put, make_mesh
from jax.sharding import NamedSharding, PartitionSpec as P

Query available devices and their properties for computation placement and resource management.
def devices(backend=None) -> list[Device]:
"""
Get list of all available devices.
Args:
backend: Optional backend name ('cpu', 'gpu', 'tpu')
Returns:
List of available Device objects
"""
def local_devices(process_index=None, backend=None) -> list[Device]:
"""
Get list of devices local to current process.
Args:
process_index: Process index (None for current process)
backend: Optional backend name
Returns:
List of local Device objects
"""
def device_count(backend=None) -> int:
"""
Get total number of devices across all processes.
Args:
backend: Optional backend name
Returns:
Total device count
"""
def local_device_count(backend=None) -> int:
"""
Get number of devices on current process.
Args:
backend: Optional backend name
Returns:
Local device count
"""
def host_count(backend=None) -> int:
"""
Get number of hosts in distributed computation.
NOTE(review): jax.host_count is a legacy alias of jax.process_count and is
deprecated in recent JAX releases -- prefer process_count (confirm against
the JAX version being documented).
Args:
backend: Optional backend name ('cpu', 'gpu', 'tpu')
Returns:
Host count
"""
def host_id(backend=None) -> int:
"""
Get ID of current host.
NOTE(review): jax.host_id is a legacy alias of jax.process_index and is
deprecated in recent JAX releases -- prefer process_index (confirm against
the JAX version being documented).
Args:
backend: Optional backend name
Returns:
Current host ID
"""
def host_ids(backend=None) -> list[int]:
"""
Get list of all host IDs.
NOTE(review): jax.host_ids is a legacy alias of jax.process_indices and is
deprecated in recent JAX releases -- prefer process_indices (confirm against
the JAX version being documented).
Args:
backend: Optional backend name
Returns:
List of host IDs
"""
def process_count(backend=None) -> int:
"""
Get number of processes in distributed computation.
Args:
backend: Optional backend name
Returns:
Process count
"""
def process_index(backend=None) -> int:
"""
Get index of current process.
Args:
backend: Optional backend name
Returns:
Current process index
"""
def process_indices(backend=None) -> list[int]:
"""
Get list of all process indices.
Args:
backend: Optional backend name
Returns:
List of process indices
"""
def default_backend() -> str:
"""
Get name of default backend.
Returns:
Default backend name string
"""Control where computations run and move data between devices and host memory.
def device_put(x, device=None, src=None) -> Array:
"""
Move array to a specified device, or lay it out according to a Sharding.
Args:
x: Array or array-like object to move
device: Target Device, or a Sharding (e.g. NamedSharding) for distributed
placement; None uses the default device
src: Source device for the transfer
Returns:
Array placed on the target device (or sharded as requested)
"""
def device_put_sharded(
sharded_values: list,
devices: list[Device],
indices=None
) -> Array:
"""
Create sharded array from per-device values.
NOTE(review): recent JAX releases deprecate device_put_sharded in favor of
jax.device_put with a Sharding, or make_array_from_single_device_arrays --
confirm for the targeted release.
Args:
sharded_values: List of arrays, one per device
devices: List of target devices
indices: Optional sharding indices
Returns:
Distributed array sharded across devices
"""
def device_put_replicated(x, devices: list[Device]) -> Array:
"""
Replicate array across multiple devices.
Args:
x: Array to replicate
devices: List of target devices
Returns:
Array replicated across all specified devices
"""
def device_get(x) -> Any:
"""
Move array from device to host memory as NumPy array.
Args:
x: Array to move to host
Returns:
NumPy array in host memory
"""
def copy_to_host_async(x) -> Any:
"""
Asynchronously copy array to host memory.
Args:
x: Array to copy
Returns:
Future-like object for async copy
"""
def block_until_ready(x) -> Array:
"""
Block until array computation is complete and ready.
Args:
x: Array to wait for
Returns:
The same array, guaranteed to be ready
"""Usage examples:
# Check available devices
all_devices = jax.devices()
print(f"Available devices: {all_devices}")
print(f"Device count: {jax.device_count()}")
# Move data to specific device
cpu_data = jnp.array([1, 2, 3, 4])
if jax.devices('gpu'):
gpu_data = jax.device_put(cpu_data, jax.devices('gpu')[0])
print(f"Data is on: {gpu_data.device()}")
# Move back to host
host_data = jax.device_get(gpu_data) # Returns NumPy array
# Explicit device placement in computations
with jax.default_device(jax.devices('cpu')[0]):
cpu_result = jnp.sum(jnp.array([1, 2, 3]))

Define how arrays are distributed across multiple devices for parallel computation.
class NamedSharding:
"""
Sharding specification using named mesh axes.
Defines how arrays are partitioned across devices using logical axis names.
"""
def __init__(self, mesh, spec):
"""
Create named sharding specification.
Args:
mesh: Device mesh with named axes
spec: Partition specification (PartitionSpec)
"""
self.mesh = mesh
self.spec = spec
class PartitionSpec:
"""
Specification for how to partition array dimensions across mesh axes.
Use P(axis_names...) to create partition specifications.
"""
pass
# Alias for PartitionSpec
P = PartitionSpec
def make_mesh(mesh_shape, axis_names) -> Mesh:
"""
Create device mesh for distributed computation.
Args:
mesh_shape: Shape of device mesh (tuple of integers), e.g. (2, 2);
the sizes must be satisfiable by the available devices
axis_names: Names for mesh axes (tuple of strings), e.g. ('data', 'model')
Returns:
Mesh object representing device layout
"""
class Mesh:
"""Device mesh for distributed computation."""
devices: Array # Device array in mesh shape
axis_names: tuple[str, ...] # Names of mesh axes
@property
def shape(self) -> dict[str, int]:
"""Dictionary mapping axis names to sizes."""
@property
def size(self) -> int:
"""Total number of devices in mesh."""
def make_array_from_single_device_arrays(
arrays: list[Array],
sharding: Sharding
) -> Array:
"""
Create distributed array from per-device arrays.
Args:
arrays: List of arrays on different devices
sharding: Sharding specification
Returns:
Distributed array with specified sharding
"""
def make_array_from_callback(
shape: tuple[int, ...],
sharding: Sharding,
data_callback: Callable
) -> Array:
"""
Create distributed array using callback function.
Args:
shape: Global array shape
sharding: Sharding specification
data_callback: Function to generate data for each shard
Returns:
Distributed array created from callback
"""
def make_array_from_process_local_data(
sharding: Sharding,
local_data: Array
) -> Array:
"""
Create distributed array from process-local data.
Args:
sharding: Sharding specification
local_data: Data local to current process
Returns:
Distributed array assembled from local data
"""Execute computations on sharded arrays with explicit control over parallelization.
def shard_map(
f: Callable,
mesh: Mesh,
in_specs,
out_specs,
check_rep=True
) -> Callable:
"""
Transform function to operate on sharded arrays.
f is applied to each per-device shard of its inputs; cross-device
communication inside f is done with collectives over named mesh axes.
Args:
f: Function to transform; receives the local shard of each argument
mesh: Device mesh for computation
in_specs: Input sharding specifications (one PartitionSpec per argument)
out_specs: Output sharding specifications
check_rep: Whether to check for replication consistency
Returns:
Function that operates on globally sharded arrays
"""
# Alias for shard_map
smap = shard_map
def with_sharding_constraint(x, sharding) -> Array:
"""
Add sharding constraint to array.
Args:
x: Input array
sharding: Desired sharding specification
Returns:
Array with sharding constraint applied
"""Usage examples:
# Create 2x2 device mesh
devices_array = np.array(jax.devices()[:4]).reshape(2, 2)  # note: numpy (not jnp) for arrays of Device objects; illustrative only -- make_mesh below builds its own grid
mesh = jax.make_mesh((2, 2), ('data', 'model'))
# Define sharding specifications
data_sharding = NamedSharding(mesh, P('data', None)) # Shard first axis across 'data'
model_sharding = NamedSharding(mesh, P(None, 'model')) # Shard second axis across 'model'
replicated_sharding = NamedSharding(mesh, P()) # Replicated across all devices
# Create sharded arrays
x = jax.random.normal(jax.random.key(0), (8, 4))
x_sharded = jax.device_put(x, data_sharding)
weights = jax.random.normal(jax.random.key(1), (4, 8))
weights_sharded = jax.device_put(weights, model_sharding)
# Computation with sharded arrays automatically parallelized
@jax.jit
def matmul_fn(x, w):
return x @ w
result = matmul_fn(x_sharded, weights_sharded) # Automatically sharded computation
# Explicit sharding control
def single_device_fn(x_shard, w_shard):
return x_shard @ w_shard
parallel_fn = jax.shard_map(
single_device_fn,
mesh=mesh,
in_specs=(P('data', None), P(None, 'model')),
out_specs=P('data', 'model')
)
result = parallel_fn(x_sharded, weights_sharded)

Control memory usage and optimize performance through explicit memory management.
def live_arrays() -> list[Array]:
"""
Get list of arrays currently alive in memory.
Returns:
List of live Array objects
"""
def clear_caches() -> None:
"""
Clear JAX's internal caches to free memory.
Clears JIT compilation cache, device buffer cache, and other internal caches.
"""Configure device behavior and backend selection.
# Configuration through jax.config
jax.config.update('jax_platform_name', 'cpu') # Force CPU backend
jax.config.update('jax_platform_name', 'gpu') # Force GPU backend
jax.config.update('jax_platform_name', 'tpu') # Force TPU backend
# Transfer guards to catch unintentional device transfers
jax.config.update('jax_transfer_guard', 'allow') # Default: allow all transfers
jax.config.update('jax_transfer_guard', 'log') # Log transfers
jax.config.update('jax_transfer_guard', 'disallow') # Disallow transfers
jax.config.update('jax_transfer_guard', 'log_explicitly') # Log only explicit transfers (also available: 'disallow_explicitly')
# Default device configuration
jax.config.update('jax_default_device', jax.devices('gpu')[0]) # Set default device

Inspect array placement and device properties.
# Array device methods
array.device -> Device # Get device containing array (a property in recent JAX; the older .device() method form is deprecated)
array.devices() -> set[Device] # Get all devices for distributed array
array.sharding -> Sharding # Get array's sharding specification
array.is_fully_replicated -> bool # Check if array is replicated
array.is_fully_addressable -> bool # Check if array is fully addressable
# Device properties
class Device:
"""Device object representing compute accelerator."""
platform: str # Platform name ('cpu', 'gpu', 'tpu')
device_kind: str # Device kind string
id: int # Device ID within platform
host_id: int # Host ID containing device
process_index: int # Process index containing device
def __str__(self) -> str: ...
def __repr__(self) -> str: ...

# Setup for data-parallel training
def create_train_setup(num_devices):
# Create mesh for data parallelism
mesh = jax.make_mesh((num_devices,), ('batch',))
# Sharding specifications
batch_sharding = NamedSharding(mesh, P('batch')) # Batch dimension sharded
replicated_sharding = NamedSharding(mesh, P()) # Parameters replicated
return mesh, batch_sharding, replicated_sharding
def distributed_train_step(params, batch, optimizer_state):
# All arrays should already have appropriate sharding
grads = jax.grad(loss_fn)(params, batch)
# Update step automatically uses sharding from inputs
new_params, new_state = optimizer.update(grads, optimizer_state, params)
return new_params, new_state
# JIT compile with sharding
distributed_train_step = jax.jit(
distributed_train_step,
in_shardings=(replicated_sharding, batch_sharding, replicated_sharding),
out_shardings=(replicated_sharding, replicated_sharding)
)

# Setup for model-parallel computation
def create_model_parallel_setup():
# 2D mesh: batch x model dimensions
mesh = jax.make_mesh((2, 4), ('batch', 'model'))
# Different sharding strategies
input_sharding = NamedSharding(mesh, P('batch', None))
weight_sharding = NamedSharding(mesh, P(None, 'model'))
output_sharding = NamedSharding(mesh, P('batch', 'model'))
return mesh, input_sharding, weight_sharding, output_sharding
def model_parallel_layer(x, weights):
# Matrix multiply with different sharding patterns
return x @ weights # JAX handles the communication automatically
# Shard arrays according to strategy
x = jax.device_put(x, input_sharding)
weights = jax.device_put(weights, weight_sharding)
result = model_parallel_layer(x, weights)  # Result has output_sharding

def memory_efficient_inference(model_fn, large_input):
# Process in chunks to manage memory
chunk_size = 1000
chunks = [large_input[i:i+chunk_size] for i in range(0, len(large_input), chunk_size)]
results = []
for chunk in chunks:
# Move to device, compute, move back to host
device_chunk = jax.device_put(chunk)
device_result = model_fn(device_chunk)
host_result = jax.device_get(device_result)
results.append(host_result)
# Optional: clear caches to free memory
jax.clear_caches()
return jnp.concatenate(results)

# Collective operations using pmap
@functools.partial(jax.pmap, axis_name='batch')
def allreduce_example(x):
    """All-reduce: sum the per-device slices of x across all devices.

    The leading axis of x must equal the local device count; each device
    gets one slice, and every device receives the cross-device sum.
    """
    # The axis name must be bound in the jax.pmap call itself; with a bare
    # @jax.pmap, lax.psum raises an "unbound axis name" error for 'batch'.
    return jax.lax.psum(x, axis_name='batch')
@functools.partial(jax.pmap, axis_name='batch')
def allgather_example(x):
    """All-gather: collect the per-device slices of x onto every device.

    Each device contributes its slice of the leading axis and receives the
    values stacked from all devices.
    """
    # As with psum, the axis name must be declared in the pmap call; a bare
    # @jax.pmap leaves 'batch' unbound and all_gather fails.
    return jax.lax.all_gather(x, axis_name='batch')
# Use with replicated data
replicated_data = jax.device_put_replicated(data, jax.devices())
summed_result = allreduce_example(replicated_data)
gathered_result = allgather_example(replicated_data)

Install with Tessl CLI
npx tessl i tessl/pypi-jax