Differentiate, compile, and transform NumPy code.
—
JAX experimental features provide access to cutting-edge capabilities, performance optimizations, and research functionality through jax.experimental. These features may change or be moved to the main JAX API in future versions.
Warning: Experimental APIs may change without notice between JAX versions. Use with caution in production code.
import jax.experimental as jex
from jax.experimental import io_callback, enable_x64

# Control floating-point precision globally across JAX computations.
def enable_x64(enable: bool = True) -> None:
    """
    Enable or disable 64-bit floating point precision.

    Args:
        enable: Whether to enable 64-bit precision (default: True)

    Note:
        This sets the ``jax_enable_x64`` config flag globally.
    """
def disable_x64() -> None:
    """
    Disable 64-bit floating point precision.

    Convenience function equivalent to ``enable_x64(False)``.
    """

# Usage examples:
# Enable double precision
jax.experimental.enable_x64()
x = jnp.array(1.0)  # Now defaults to float64 instead of float32
print(x.dtype)      # dtype('float64')

# Disable double precision
jax.experimental.disable_x64()
y = jnp.array(1.0)  # Back to float32
print(y.dtype)      # dtype('float32')

# Enable host callbacks for I/O operations and side effects within JAX computations.
def io_callback(
    callback: Callable,
    result_shape_dtypes,
    *args,
    sharding=None,
    vmap_method=None,
    ordered=False,
    **kwargs,
) -> Any:
    """
    Call host function from within JAX computation with I/O side effects.

    Args:
        callback: Host function to call (should be pure except for I/O)
        result_shape_dtypes: Shape and dtype specification for callback result
        args: Arguments to pass to callback
        sharding: Sharding specification for result
        vmap_method: How to handle vmapping ('sequential', 'expand_dims', etc.)
        ordered: Whether to maintain call ordering across devices
        kwargs: Additional keyword arguments for callback

    Returns:
        Result of callback with specified shape and dtype
    """

# Usage examples:
# Logging during computation (debugging)
def log_value(x, step):
    print(f"Step {step}: value = {x}")
    return x

@jax.jit
def training_step(x, step):
    # Log intermediate values during training
    x = jax.experimental.io_callback(
        log_value,
        jax.ShapeDtypeStruct(x.shape, x.dtype),
        x, step,
    )
    return x * 2
# File I/O during computation
def save_checkpoint(params, step):
    import pickle
    with open(f'checkpoint_{step}.pkl', 'wb') as f:
        pickle.dump(params, f)
    return step

@jax.jit
def train_with_checkpointing(params, data, step):
    # Training computation
    loss = compute_loss(params, data)
    grads = jax.grad(compute_loss)(params, data)
    new_params = update_params(params, grads)
    # Save checkpoint every 100 steps
    step = jax.experimental.io_callback(
        save_checkpoint,
        jax.ShapeDtypeStruct((), jnp.int32),
        new_params, step,
    )
    return new_params, loss

# Experimental differentiation features and optimizations.
def saved_input_vjp(f, *primals) -> tuple[Any, Callable]:
    """
    Vector-Jacobian product with saved inputs for memory efficiency.

    Args:
        f: Function to differentiate
        primals: Input values

    Returns:
        Tuple of (primal_out, vjp_fun) where vjp_fun has access to saved inputs
    """

# Alias for saved_input_vjp
si_vjp = saved_input_vjp

# Usage example:
def expensive_function(x, y):
    # Some expensive computation that we want to differentiate
    z = jnp.exp(x) + jnp.sin(y)
    return jnp.sum(z ** 2)

# Use saved input VJP for memory efficiency
x, y = jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])
primal_out, vjp_fn = jax.experimental.saved_input_vjp(expensive_function, x, y)

# Compute VJP with cotangent
cotangent = 1.0
x_grad, y_grad = vjp_fn(cotangent)

# Experimental array types and extended functionality.
class EArray:
    """
    Extended array type with additional metadata and functionality.

    Experimental array type that may include additional features
    beyond standard JAX arrays.
    """
    pass
class MutableArray:
    """
    Experimental mutable array type for specific use cases.

    Warning: Breaks JAX's functional programming model. Use carefully.
    """
    pass
def mutable_array(init_val) -> MutableArray:
    """
    Create mutable array from initial value.

    Args:
        init_val: Initial array value

    Returns:
        MutableArray that can be modified in-place
    """

# Experimental extensions to JAX's type system.
def primal_tangent_dtype(primal_dtype, tangent_dtype=None):
    """
    Create dtype for primal-tangent pairs in forward-mode AD.

    Args:
        primal_dtype: Data type for primal values
        tangent_dtype: Data type for tangent values (defaults to primal_dtype)

    Returns:
        Combined dtype for primal-tangent computation
    """

# Experimental compilation features and performance optimizations.
# Compilation control
def disable_jit_cache() -> None:
    """Disable JIT compilation cache for debugging."""

def enable_jit_cache() -> None:
    """Re-enable JIT compilation cache."""

# Performance monitoring
def compilation_cache_stats() -> dict:
    """Get statistics about JIT compilation cache."""

def clear_compilation_cache() -> None:
    """Clear JIT compilation cache."""

# Experimental features for specific hardware accelerators.
# TPU-specific features
class TPUMemoryFraction:
    """Control TPU memory usage fraction."""

def set_tpu_memory_fraction(fraction: float) -> None:
    """
    Set fraction of TPU memory to use.

    Args:
        fraction: Memory fraction (0.0 to 1.0)
    """

# GPU-specific features
def gpu_memory_stats() -> dict:
    """Get GPU memory usage statistics."""

def set_gpu_memory_growth(enable: bool) -> None:
    """
    Enable/disable GPU memory growth.

    Args:
        enable: Whether to enable incremental memory allocation
    """

# Experimental automatic mixed precision for training acceleration.
class AutoMixedPrecision:
    """Automatic mixed precision policy for training."""

    def __init__(self, policy='float16'):
        """
        Initialize AMP policy.

        Args:
            policy: Precision policy ('float16', 'bfloat16', etc.)
        """
        self.policy = policy

    def __call__(self, fn):
        """Apply AMP to function."""
        pass
def amp_policy(policy_name: str) -> AutoMixedPrecision:
    """
    Create automatic mixed precision policy.

    Args:
        policy_name: Name of precision policy

    Returns:
        AMP policy object
    """

# Experimental distributed computing features beyond standard pmap/shard_map.
def multi_host_utils():
    """Utilities for multi-host distributed computation."""
    pass

class GlobalDeviceArray:
    """
    Experimental global device array for large-scale distributed computation.

    Represents arrays that span multiple hosts in distributed setting.
    """
    pass
def create_global_device_array(
    shape,
    dtype,
    mesh,
    partition_spec,
) -> GlobalDeviceArray:
    """
    Create global device array across distributed system.

    Args:
        shape: Global array shape
        dtype: Array data type
        mesh: Device mesh specification
        partition_spec: How to partition array

    Returns:
        Global device array
    """

# Cutting-edge research features that may be highly experimental.
# Sparsity support
class SparseArray:
    """Experimental sparse array support."""
    pass

def sparse_ops():
    """Sparse operations module (highly experimental)."""
    pass

# Quantization support
def quantized_dot(lhs, rhs, **kwargs):
    """Experimental quantized matrix multiplication."""
    pass

def quantization_utils():
    """Utilities for quantized computation."""
    pass

# Custom operators
def custom_op_builder():
    """Builder for custom XLA operations."""
    pass

# Advanced compilation
def ahead_of_time_compile(fn, *args, **kwargs):
    """Ahead-of-time compilation (experimental)."""
    pass

# Experimental debugging and profiling tools.
def debug_callback(callback, *args, **kwargs):
    """
    Debug callback that doesn't affect computation graph.

    Args:
        callback: Debug function to call
        args: Arguments to callback
        kwargs: Keyword arguments to callback
    """

def trace_function(fn):
    """
    Trace function execution for debugging.

    Args:
        fn: Function to trace

    Returns:
        Traced version of function
    """

def memory_profiler():
    """Memory profiling utilities."""
    pass

def computation_graph_visualizer():
    """Tools for visualizing computation graphs."""
    pass

# When experimental features graduate to main JAX API:
# Old experimental usage
from jax.experimental import feature_name

# New main API usage (after graduation)
from jax import feature_name

# Or sometimes moves to different module
from jax.some_module import feature_name

# 1. Version pinning when using experimental features
# requirements.txt: jax==0.7.1  # Pin exact version

# 2. Graceful fallbacks
try:
    from jax.experimental import new_feature
    use_experimental = True
except ImportError:
    use_experimental = False

def my_function(x):
    if use_experimental:
        return new_feature.optimized_op(x)
    else:
        return traditional_op(x)

# 3. Feature flags for experimental code
USE_EXPERIMENTAL_AMP = False
if USE_EXPERIMENTAL_AMP:
    amp_policy = jax.experimental.amp_policy('float16')
    train_fn = amp_policy(train_fn)

# 4. Documentation and warnings
def experimental_model_fn(x):
    """
    Model function using experimental JAX features.

    Warning: Uses jax.experimental.* APIs that may change.
    Tested with JAX v0.7.1.
    """
    # Implementation using experimental features
    pass

import pytest
import pytest
import jax

# Skip tests if experimental feature not available
@pytest.mark.skipif(
    not hasattr(jax.experimental, 'new_feature'),
    reason="Experimental feature not available",
)
def test_experimental_feature():
    # Test experimental functionality
    pass

# Conditional testing based on JAX version
jax_version = tuple(map(int, jax.__version__.split('.')[:2]))

@pytest.mark.skipif(
    jax_version < (0, 7),
    reason="Feature requires JAX >= 0.7",
)
def test_version_dependent_feature():
    # Test version-dependent experimental feature
    pass

# Install with Tessl CLI
npx tessl i tessl/pypi-jax