CtrlK
BlogDocsLog inGet started
Tessl Logo

tessl/pypi-jax

Differentiate, compile, and transform Numpy code.

Pending
Overview
Eval results
Files

docs/experimental.md

Experimental Features

JAX experimental features provide access to cutting-edge capabilities, performance optimizations, and research functionality through jax.experimental. These features may change or be moved to the main JAX API in future versions.

Warning: Experimental APIs may change without notice between JAX versions. Use with caution in production code.

Core Imports

import jax
import jax.experimental as jex
import jax.numpy as jnp
from jax.experimental import io_callback, enable_x64

Capabilities

Precision Control

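Temporarily switch JAX's default floating-point precision between 32-bit and 64-bit inside a `with` block; the previous setting is restored when the block exits. To change the default for the whole program instead, set the `jax_enable_x64` configuration flag at startup (for example `jax.config.update("jax_enable_x64", True)`).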

def enable_x64(enable: bool = True) -> "contextlib.AbstractContextManager[None]":
    """
    Context manager that temporarily enables (or disables) 64-bit precision.

    Use it in a ``with`` statement. Inside the block, new arrays default to
    64-bit types when ``enable`` is True. The previous setting is restored
    when the block exits. Calling it without ``with`` does not change the
    precision.

    Args:
        enable: Whether 64-bit precision is active inside the ``with``
            block (default: True).

    Returns:
        A context manager.

    Note:
        To change the default for the whole program instead, set the
        ``jax_enable_x64`` configuration flag, e.g.
        ``jax.config.update("jax_enable_x64", True)``.
        NOTE(review): the released function names this parameter
        ``new_val``; confirm before passing it by keyword.
    """

def disable_x64() -> "contextlib.AbstractContextManager[None]":
    """
    Context manager that temporarily disables 64-bit precision.

    Equivalent to ``enable_x64(False)``. Inside the ``with`` block, new
    arrays default to 32-bit types, and the previous setting is restored
    on exit. Calling it without ``with`` does not change the precision.
    """

Usage examples:

# enable_x64 / disable_x64 are context managers: they only take effect
# inside a `with` block (calling them bare does nothing).

# Enable double precision for the duration of the block
with jax.experimental.enable_x64():
    x = jnp.array(1.0)  # Defaults to float64 inside the block
    print(x.dtype)  # float64

    # Disable double precision again for a nested scope
    with jax.experimental.disable_x64():
        y = jnp.array(1.0)  # Back to float32
        print(y.dtype)  # float32

I/O and Callbacks

Enable host callbacks for I/O operations and side effects within JAX computations.

def io_callback(
    callback: Callable,
    result_shape_dtypes,
    *args,
    sharding=None,
    vmap_method=None,
    ordered=False,
    **kwargs
) -> Any:
    """
    Invoke a host-side Python function that performs I/O from JAX code.

    Args:
        callback: Python function run on the host. Apart from its I/O side
            effects it should behave like a pure function.
        result_shape_dtypes: Shape/dtype description (e.g. a
            ``jax.ShapeDtypeStruct``) of the value ``callback`` returns.
        args: Positional arguments forwarded to ``callback``.
        sharding: Optional sharding for the result.
        vmap_method: Strategy used when the call is vmapped
            ('sequential', 'expand_dims', ...).
            NOTE(review): confirm this keyword is accepted by ``io_callback``
            in your JAX version.
        ordered: Whether the ordering of calls must be preserved.
        kwargs: Keyword arguments forwarded to ``callback``.

    Returns:
        The value produced by ``callback``, matching ``result_shape_dtypes``.
    """

Usage examples:

# Logging during computation (debugging)
def log_value(x, step):
    """Host-side logger: print ``x`` tagged with ``step``, then return ``x``."""
    line = f"Step {step}: value = {x}"
    print(line)
    return x

@jax.jit
def training_step(x, step):
    """Log ``x`` on the host through ``log_value`` and return ``x * 2``."""
    # log_value hands x back unchanged, so the result spec mirrors x.
    spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    logged = jax.experimental.io_callback(log_value, spec, x, step)
    return logged * 2

# File I/O during computation
def save_checkpoint(params, step):
    """Pickle ``params`` to ``checkpoint_<step>.pkl`` in the CWD; return ``step``."""
    import pickle

    path = f'checkpoint_{step}.pkl'
    with open(path, 'wb') as handle:
        pickle.dump(params, handle)
    return step

@jax.jit
def train_with_checkpointing(params, data, step):
    """
    Run one training step and checkpoint the updated params every 100 steps.

    Relies on ``compute_loss``, ``update_params`` and ``save_checkpoint``
    being defined elsewhere.

    Returns:
        ``(new_params, loss)``
    """
    # One forward+backward pass instead of evaluating compute_loss twice.
    loss, grads = jax.value_and_grad(compute_loss)(params, data)
    new_params = update_params(params, grads)

    def _maybe_save(p, s):
        # Runs on the host with concrete values, so a plain Python branch
        # is enough to save only every 100 steps.
        if int(s) % 100 == 0:
            save_checkpoint(p, s)
        return s

    # Derive the result spec from `step` itself so it matches whatever
    # integer dtype `step` has (e.g. int64 when x64 is enabled).
    jax.experimental.io_callback(
        _maybe_save,
        jax.ShapeDtypeStruct(jnp.shape(step), jnp.result_type(step)),
        new_params, step
    )

    return new_params, loss

Advanced Differentiation

Experimental differentiation features and optimizations.

def saved_input_vjp(f, *primals) -> tuple[Any, Callable]:
    """
    Reverse-mode VJP variant that reuses saved inputs to reduce memory.

    Args:
        f: Function to be differentiated.
        primals: Points at which ``f`` is evaluated.

    Returns:
        ``(primal_out, vjp_fun)``. ``vjp_fun`` has access to the saved inputs.

    NOTE(review): the released ``jax.experimental.saved_input_vjp`` may take
    additional arguments (e.g. which inputs to save); confirm the signature
    against your installed JAX version.
    """


# Short alias for saved_input_vjp
si_vjp = saved_input_vjp

Usage example:

def expensive_function(x, y):
    """Return ``sum((exp(x) + sin(y)) ** 2)``, a scalar worth differentiating."""
    combined = jnp.exp(x) + jnp.sin(y)
    squared = combined ** 2
    return jnp.sum(squared)

# Memory-efficient VJP using saved inputs.
# NOTE(review): confirm the call signature of saved_input_vjp and of the
# returned VJP function against your installed JAX version.
x = jnp.array([1.0, 2.0])
y = jnp.array([3.0, 4.0])
primal_out, vjp_fn = jax.experimental.saved_input_vjp(expensive_function, x, y)

# expensive_function returns a scalar, so pull back a scalar cotangent.
cotangent = 1.0
x_grad, y_grad = vjp_fn(cotangent)

Extended Array Types

Experimental array types and extended functionality.

class EArray:
    """
    Experimental extended array type.

    May carry extra metadata or functionality beyond what a standard JAX
    array provides.
    """

class MutableArray:
    """
    Experimental array type whose contents can be changed in place.

    Warning: in-place mutation breaks JAX's functional programming model;
    use it carefully.
    """

def mutable_array(init_val) -> MutableArray:
    """
    Build a ``MutableArray`` seeded with ``init_val``.

    Args:
        init_val: Value the array starts out with.

    Returns:
        A ``MutableArray`` whose contents can be updated in place.
    """

Type System Extensions

Experimental extensions to JAX's type system.

def primal_tangent_dtype(primal_dtype, tangent_dtype=None):
    """
    Build a combined dtype describing primal/tangent pairs for forward-mode AD.

    Args:
        primal_dtype: Dtype of the primal values.
        tangent_dtype: Dtype of the tangent values; falls back to
            ``primal_dtype`` when omitted.

    Returns:
        The combined primal-tangent dtype.

    NOTE(review): confirm the exact signature against your installed JAX
    version.
    """

Compilation and Performance

Experimental compilation features and performance optimizations.

# Compilation-cache control
def disable_jit_cache() -> None:
    """
    Turn the JIT compilation cache off (debugging aid).

    NOTE(review): confirm this function exists in your installed JAX version.
    """

def enable_jit_cache() -> None:
    """
    Turn the JIT compilation cache back on.

    NOTE(review): confirm this function exists in your installed JAX version.
    """

# Performance monitoring
def compilation_cache_stats() -> dict:
    """
    Report statistics about the JIT compilation cache.

    NOTE(review): confirm this function exists in your installed JAX version.
    """

def clear_compilation_cache() -> None:
    """
    Empty the JIT compilation cache.

    NOTE(review): ``jax.clear_caches()`` may be the supported equivalent;
    confirm which API your JAX version provides.
    """

Hardware-Specific Features

Experimental features for specific hardware accelerators.

# TPU-specific features
class TPUMemoryFraction:
    """
    Setting that controls the fraction of TPU memory in use.

    NOTE(review): confirm this class exists in your installed JAX version.
    """
    
def set_tpu_memory_fraction(fraction: float) -> None:
    """
    Choose what fraction of TPU memory may be used.

    Args:
        fraction: Share of memory to use, between 0.0 and 1.0.

    NOTE(review): confirm this function exists in your installed JAX version.
    """

# GPU-specific features
def gpu_memory_stats() -> dict:
    """
    Report GPU memory usage statistics.

    NOTE(review): per-device statistics may be available through
    ``Device.memory_stats()``; confirm which API your JAX version provides.
    """

def set_gpu_memory_growth(enable: bool) -> None:
    """
    Switch incremental GPU memory allocation on or off.

    Args:
        enable: True to allocate GPU memory incrementally instead of upfront.

    NOTE(review): JAX may control this through the
    ``XLA_PYTHON_CLIENT_PREALLOCATE`` environment variable instead; confirm.
    """

Automatic Mixed Precision

Experimental automatic mixed precision for training acceleration.

class AutoMixedPrecision:
    """
    Automatic mixed precision (AMP) policy for training.

    Placeholder implementation: applying the policy returns the wrapped
    function unchanged, so ``fn = policy(fn)`` never replaces a callable
    with ``None``.
    """

    def __init__(self, policy='float16'):
        """
        Initialize AMP policy.

        Args:
            policy: Precision policy ('float16', 'bfloat16', etc.)
        """
        self.policy = policy

    def __call__(self, fn):
        """
        Apply the AMP policy to ``fn`` and return the resulting function.

        No dtype casting is performed yet; ``fn`` is returned as-is.
        """
        return fn

    def __repr__(self):
        return f"{type(self).__name__}(policy={self.policy!r})"

def amp_policy(policy_name: str) -> AutoMixedPrecision:
    """
    Create automatic mixed precision policy.

    Args:
        policy_name: Name of precision policy (e.g. 'float16', 'bfloat16')

    Returns:
        An ``AutoMixedPrecision`` object wrapping ``policy_name``; call it on
        a function to apply the policy.

    NOTE(review): confirm this factory exists in your installed JAX version.
    """
    return AutoMixedPrecision(policy_name)

Distributed Computing Extensions

Experimental distributed computing features beyond standard pmap/shard_map.

def multi_host_utils():
    """
    Helpers for computations distributed over several hosts.

    NOTE(review): JAX ships a ``jax.experimental.multihost_utils`` module;
    confirm whether this function is meant to refer to it.
    """
    return None

class GlobalDeviceArray:
    """
    Experimental array that spans the devices of several hosts.

    Intended for large-scale distributed computation.

    NOTE(review): newer JAX releases represent multi-host arrays with
    ``jax.Array``; confirm this class is still available in your version.
    """

def create_global_device_array(
    shape,
    dtype,
    mesh,
    partition_spec
) -> GlobalDeviceArray:
    """
    Build a ``GlobalDeviceArray`` laid out across a distributed system.

    Args:
        shape: Shape of the full (global) array.
        dtype: Element dtype.
        mesh: Device mesh to place the array on.
        partition_spec: How the array is split across the mesh.

    Returns:
        The resulting global device array.

    NOTE(review): with ``jax.Array`` the equivalent is typically
    ``jax.make_array_from_callback`` plus a ``NamedSharding``; confirm.
    """

Research and Prototype Features

Cutting-edge research features that may be highly experimental.

# Sparsity support
class SparseArray:
    """
    Experimental sparse array type.

    NOTE(review): JAX's sparse support lives in ``jax.experimental.sparse``
    (e.g. ``BCOO``); confirm this class name.
    """

def sparse_ops():
    """Entry point for sparse operations (highly experimental)."""
    return None

# Quantization support
def quantized_dot(lhs, rhs, **kwargs):
    """
    Quantized matrix multiplication of ``lhs`` and ``rhs`` (experimental).

    Extra keyword arguments are accepted for configuration.
    """
    return None

def quantization_utils():
    """Helpers for running quantized computations."""
    return None

# Custom operators
def custom_op_builder():
    """Entry point for building custom XLA operations."""
    return None

# Advanced compilation
def ahead_of_time_compile(fn, *args, **kwargs):
    """
    Compile ``fn`` ahead of time for the given arguments (experimental).

    NOTE(review): stable JAX exposes AOT compilation as
    ``jax.jit(fn).lower(*args).compile()``; prefer that unless this helper
    is confirmed to exist.
    """
    return None

Debugging and Profiling

Experimental debugging and profiling tools.

def debug_callback(callback, *args, **kwargs):
    """
    Invoke a debugging hook without altering the computation graph.

    Args:
        callback: Function called for debugging purposes.
        args: Positional arguments passed to ``callback``.
        kwargs: Keyword arguments passed to ``callback``.

    NOTE(review): the supported public API is ``jax.debug.callback``;
    confirm this experimental alias exists.
    """

def trace_function(fn):
    """
    Wrap ``fn`` so its execution is traced for debugging.

    Args:
        fn: Function whose execution should be traced.

    Returns:
        The traced version of ``fn``.
    """

def memory_profiler():
    """Entry point for memory-profiling utilities."""
    return None

def computation_graph_visualizer():
    """Entry point for tools that visualize computation graphs."""
    return None

Migration Patterns

When an experimental feature graduates to the main JAX API, update its import path:

# Old experimental usage
from jax.experimental import feature_name

# New main API usage (after graduation)  
from jax import feature_name

# Or sometimes moves to different module
from jax.some_module import feature_name

Usage Guidelines

Best Practices for Experimental Features

# 1. Version pinning when using experimental features
# requirements.txt: jax==0.7.1  # Pin exact version

# 2. Graceful fallbacks
try:
    from jax.experimental import new_feature
    use_experimental = True
except ImportError:
    use_experimental = False

def my_function(x):
    """Use the experimental op when it imported successfully, else the fallback."""
    op = new_feature.optimized_op if use_experimental else traditional_op
    return op(x)

# 3. Feature flags for experimental code
USE_EXPERIMENTAL_AMP = False

if USE_EXPERIMENTAL_AMP:
    # Bind the policy to its own name so the `amp_policy` factory is not
    # shadowed. `train_fn` is assumed to be defined by the caller.
    policy = jax.experimental.amp_policy('float16')
    train_fn = policy(train_fn)

# 4. Documentation and warnings
def experimental_model_fn(x):
    """
    Model function built on experimental JAX features.

    Warning: relies on jax.experimental.* APIs, which may change.
    Tested with JAX v0.7.1.
    """
    # Placeholder: the experimental implementation goes here.
    return None

Testing Experimental Features

# `jax` must be imported before the decorators below evaluate
# `jax.experimental`; importing the submodule explicitly guarantees it is
# loaded for the hasattr() check.
import jax
import jax.experimental
import pytest

# Skip tests if experimental feature not available
@pytest.mark.skipif(
    not hasattr(jax.experimental, 'new_feature'),
    reason="Experimental feature not available"
)
def test_experimental_feature():
    # Test experimental functionality
    pass

# Conditional testing based on JAX version: compare only (major, minor).
jax_version = tuple(map(int, jax.__version__.split('.')[:2]))

@pytest.mark.skipif(
    jax_version < (0, 7),
    reason="Feature requires JAX >= 0.7"
)
def test_version_dependent_feature():
    # Test version-dependent experimental feature
    pass

Install with Tessl CLI

npx tessl i tessl/pypi-jax

docs

core-transformations.md

device-memory.md

experimental.md

index.md

low-level-ops.md

neural-networks.md

numpy-compatibility.md

random-numbers.md

scipy-compatibility.md

tree-operations.md

tile.json