tessl/pypi-jax

Differentiate, compile, and transform NumPy code.

Workspace: tessl
Visibility: Public
Describes: pkg:pypi/jax@0.7.x

To install, run

npx @tessl/cli install tessl/pypi-jax@0.7.0


JAX

JAX is a NumPy-compatible library that provides composable transformations of Python+NumPy programs: differentiate, compile, and transform NumPy code. It brings together automatic differentiation (grad), just-in-time compilation (jit), vectorization (vmap), and parallelization (pmap), with support for CPUs, GPUs, and TPUs.

Package Information

  • Package Name: jax
  • Language: Python
  • Installation: pip install jax[cpu] (CPU) or pip install jax[cuda12] (GPU)

Core Imports

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, pmap

Import specific transformations:

from jax import (
    grad, jit, vmap, pmap, jacfwd, jacrev, 
    hessian, value_and_grad, checkpoint
)

Import array types and devices:

from jax import Array, Device
import jax.numpy as jnp
import jax.random as jr
import jax.lax as lax
import jax.scipy as jsp
import jax.nn as jnn
import jax.tree as tree

Basic Usage

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap

# NumPy-compatible arrays and operations
x = jnp.array([1.0, 2.0, 3.0, 4.0])
y = jnp.sum(x ** 2)  # JAX arrays work like NumPy

# Automatic differentiation
def loss_fn(params, x, y):
    pred = params[0] * x + params[1]
    return jnp.mean((pred - y) ** 2)

# Compute gradient of loss function
grad_fn = grad(loss_fn)
params = jnp.array([0.5, 0.1])
gradients = grad_fn(params, x, y)

# Just-in-time compilation for performance
@jit
def fast_function(x):
    return jnp.sum(x ** 2) + jnp.sin(x).sum()

result = fast_function(x)

# Vectorization across batch dimension
@vmap
def process_batch(single_input):
    return single_input ** 2 + jnp.sin(single_input)

batch_data = jnp.array([[1, 2], [3, 4], [5, 6]])
batch_result = process_batch(batch_data)

# Random number generation
key = jax.random.key(42)
random_data = jax.random.normal(key, (10, 5))

# Device management
print(f"Available devices: {jax.devices()}")
x_on_device = jax.device_put(x, jax.devices()[0])  # place on the first available device

Architecture

JAX's power comes from its composable function transformations that can be applied to pure Python functions:

  • Pure Functions: JAX transformations require functions to be functionally pure (no side effects)
  • Function Transformations: grad, jit, vmap, pmap can be arbitrarily composed
  • XLA Compilation: Just-in-time compilation to optimized accelerator code
  • Array Programming: NumPy-compatible array operations with immutable semantics
  • Device Model: Transparent execution across CPU, GPU, and TPU with explicit device management

The composability enables powerful patterns like jit(grad(loss_fn)) or vmap(grad(per_example_loss)).
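
For instance, a minimal sketch (with a toy squared-error loss made up for illustration) of composing these transformations to get compiled per-example gradients:

import jax.numpy as jnp
from jax import grad, jit, vmap

# Toy per-example loss; w is a weight vector (illustrative only)
def per_example_loss(w, x, y):
    return (jnp.dot(w, x) - y) ** 2

# vmap(grad(...)) yields one gradient per example; jit compiles the whole pipeline
per_example_grads = jit(vmap(grad(per_example_loss), in_axes=(None, 0, 0)))

w = jnp.array([0.5, -0.3])
xs = jnp.ones((8, 2))                  # batch of 8 examples
ys = jnp.zeros(8)
grads = per_example_grads(w, xs, ys)   # shape (8, 2): one gradient per example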

Capabilities

Core Program Transformations

The fundamental transformations behind automatic differentiation, compilation, vectorization, and parallelization. They are the core of JAX's power and can be arbitrarily composed.

def jit(fun: Callable, **kwargs) -> Callable: ...
def grad(fun: Callable, argnums: int | Sequence[int] = 0, **kwargs) -> Callable: ...
def vmap(fun: Callable, in_axes=0, out_axes=0, **kwargs) -> Callable: ...
def pmap(fun: Callable, axis_name=None, **kwargs) -> Callable: ...
def value_and_grad(fun: Callable, argnums: int | Sequence[int] = 0, **kwargs) -> Callable: ...
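
As a hedged sketch (the linear-model loss below is made up), value_and_grad returns the loss and its gradients in one call, and argnums selects which arguments to differentiate:

import jax.numpy as jnp
from jax import jit, value_and_grad

def loss_fn(w, b, x, y):
    pred = x @ w + b
    return jnp.mean((pred - y) ** 2)

# Differentiate with respect to both w (argument 0) and b (argument 1)
loss_and_grads = jit(value_and_grad(loss_fn, argnums=(0, 1)))

w, b = jnp.zeros(3), 0.0
x, y = jnp.ones((4, 3)), jnp.ones(4)
loss, (dw, db) = loss_and_grads(w, b, x, y)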

Core Transformations

NumPy Compatibility API

NumPy-compatible array operations including creation, manipulation, mathematical functions, linear algebra, and reductions. JAX arrays are immutable and follow the NumPy API closely, with the added benefits of JIT compilation and automatic differentiation.

# Array creation
def array(object, dtype=None, **kwargs) -> Array: ...
def zeros(shape, dtype=None) -> Array: ...
def ones(shape, dtype=None) -> Array: ...
def arange(start, stop=None, step=None, dtype=None) -> Array: ...

# Mathematical operations
def sum(a, axis=None, **kwargs) -> Array: ...
def mean(a, axis=None, **kwargs) -> Array: ...
def dot(a, b) -> Array: ...
def matmul(x1, x2) -> Array: ...
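
A brief sketch of the immutable-update style, the main place where JAX arrays intentionally differ from NumPy's in-place semantics:

import jax.numpy as jnp

a = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)
b = jnp.ones((3, 2))
c = a @ b                       # familiar NumPy-style matmul

# JAX arrays are immutable: use .at[...] instead of in-place assignment
a2 = a.at[0, 1].set(10.0)       # returns a new array; a is unchanged
total = jnp.sum(a2, axis=0)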

NumPy Compatibility

Neural Network Functions

Activation functions, initializers, and neural network utilities commonly used in machine learning. Includes standard activations such as ReLU, sigmoid, and softmax, modern variants such as GELU and SiLU (Swish), and utilities like one-hot encoding and dot-product attention.

def relu(x) -> Array: ...
def sigmoid(x) -> Array: ...
def softmax(x, axis=-1) -> Array: ...
def gelu(x, approximate=True) -> Array: ...
def silu(x) -> Array: ...
def one_hot(x, num_classes, **kwargs) -> Array: ...
def dot_product_attention(query, key, value, **kwargs) -> Array: ...
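
An illustrative sketch (logits and labels are made up) of building a cross-entropy loss from these functions:

import jax.numpy as jnp
import jax.nn as jnn

logits = jnp.array([[2.0, 0.5, -1.0],
                    [0.1, 1.2,  0.3]])
labels = jnp.array([0, 1])

probs = jnn.softmax(logits, axis=-1)             # per-row probabilities
targets = jnn.one_hot(labels, num_classes=3)
# log_softmax is numerically safer than log(softmax(...))
loss = -jnp.mean(jnp.sum(targets * jnn.log_softmax(logits), axis=-1))
hidden = jnn.relu(logits)                        # standard activation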

Neural Networks

Random Number Generation

Functional pseudo-random number generation with explicit key management. JAX uses a functional approach to random numbers that enables reproducibility, parallelization, and vectorization.

def key(seed: int) -> Array: ...
def split(key: Array, num: int = 2) -> Array: ...
def normal(key: Array, shape=(), dtype=float) -> Array: ...
def uniform(key: Array, shape=(), minval=0.0, maxval=1.0) -> Array: ...
def categorical(key: Array, logits, **kwargs) -> Array: ...
def choice(key: Array, a, **kwargs) -> Array: ...
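
A minimal sketch of explicit key management; a key is split before each independent use, and reusing the same key reproduces the same draw:

import jax
import jax.numpy as jnp

key = jax.random.key(0)
key, k_init, k_noise = jax.random.split(key, 3)   # split before each independent use

weights = jax.random.normal(k_init, (4, 4))
noise = jax.random.uniform(k_noise, (4,), minval=-0.1, maxval=0.1)

# The same key always yields the same values (reproducibility)
assert bool(jnp.all(jax.random.normal(k_init, (4, 4)) == weights))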

Random Numbers

Low-Level Operations

Direct XLA operations and primitives for high-performance computing. These provide the building blocks for JAX's higher-level operations and enable custom operations and optimizations.

def add(x, y) -> Array: ...
def mul(x, y) -> Array: ...
def dot_general(lhs, rhs, dimension_numbers, **kwargs) -> Array: ...
def conv_general_dilated(lhs, rhs, **kwargs) -> Array: ...
def reduce_sum(operand, axes) -> Array: ...
def cond(pred, true_fun, false_fun, *operands) -> Any: ...
def while_loop(cond_fun, body_fun, init_val) -> Any: ...
def scan(f, init, xs, **kwargs) -> tuple[Any, Array]: ...
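
A hedged sketch of lax.scan and lax.cond, which stand in for Python loops and branches inside jit-compiled code:

import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def cumulative_sum(xs):
    # scan carries a running total and emits it at every step
    def step(carry, x):
        total = carry + x
        return total, total
    _, out = lax.scan(step, jnp.zeros((), xs.dtype), xs)
    return out

@jax.jit
def clamp_negative_to_zero(x):
    # cond selects a branch from a traced predicate, not a Python if
    return lax.cond(x > 0, lambda v: v, lambda v: jnp.zeros_like(v), x)

cumulative_sum(jnp.array([1.0, 2.0, 3.0]))       # [1. 3. 6.]
clamp_negative_to_zero(jnp.array(-2.0))          # 0.0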

Low-Level Operations

SciPy Compatibility

SciPy-compatible functions for scientific computing including linear algebra, signal processing, special functions, statistics, and sparse operations. Provides a familiar interface for scientific Python users.

# Linear algebra (jax.scipy.linalg)
def solve(a, b) -> Array: ...
def eig(a, **kwargs) -> tuple[Array, Array]: ...
def svd(a, **kwargs) -> tuple[Array, Array, Array]: ...

# Special functions (jax.scipy.special)  
def logsumexp(a, **kwargs) -> Array: ...
def erf(x) -> Array: ...
def gamma(x) -> Array: ...

# Statistics (jax.scipy.stats)
norm.pdf(x, loc=0, scale=1) -> Array
multivariate_normal.pdf(x, mean, cov) -> Array
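
A small sketch combining a few of these, assuming the usual jax.scipy submodule layout:

import jax.numpy as jnp
import jax.scipy as jsp

A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([9.0, 8.0])
x = jsp.linalg.solve(A, b)                                   # solve Ax = b

log_z = jsp.special.logsumexp(jnp.array([-1.0, 0.0, 1.0]))   # stable log-sum-exp
p = jsp.stats.norm.pdf(0.5, loc=0.0, scale=1.0)              # Gaussian density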

SciPy Compatibility

Tree Operations

Utilities for working with PyTrees (nested Python structures containing arrays). Essential for handling complex data structures in functional programming patterns and neural network parameters.

def tree_map(f, tree, *rest) -> Any: ...
def tree_reduce(function, tree, **kwargs) -> Any: ...
def tree_flatten(tree) -> tuple[list, Any]: ...
def tree_unflatten(treedef, leaves) -> Any: ...
def tree_leaves(tree) -> list: ...
def tree_structure(tree) -> Any: ...
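
An illustrative sketch of mapping over nested parameters, written with the jax.tree_util spellings of these helpers:

import jax
import jax.numpy as jnp

params = {
    "dense": {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)},
    "out":   {"w": jnp.ones((3, 1)), "b": jnp.zeros(1)},
}

# Apply a function to every leaf array in the nested structure
scaled = jax.tree_util.tree_map(lambda p: 0.1 * p, params)

leaves, treedef = jax.tree_util.tree_flatten(params)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
num_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))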

Tree Operations

Device and Memory Management

Device placement, memory management, and distributed computing primitives. Enables efficient use of accelerators and scaling across multiple devices.

def devices() -> list[Device]: ...
def device_put(x, device=None) -> Array: ...
def device_get(x) -> Any: ...
class NamedSharding: ...
def make_mesh(*mesh_axes, axis_names=None) -> Mesh: ...
def shard_map(f, mesh, in_specs, out_specs, **kwargs) -> Callable: ...
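
A minimal, hedged sketch of explicit device placement; sharded execution with NamedSharding and shard_map follows the same pattern but needs multiple devices:

import jax
import jax.numpy as jnp

print(jax.devices())                           # all visible devices
print(jax.local_device_count())                # devices attached to this host

x = jnp.arange(8.0)
x_dev = jax.device_put(x, jax.devices()[0])    # place on a specific device
x_host = jax.device_get(x_dev)                 # copy back as a NumPy array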

Device and Memory Management

Experimental Features

Cutting-edge and experimental JAX features including new APIs, performance optimizations, and research capabilities. These features may change in future versions.

def io_callback(callback, result_shape_dtypes, *args, **kwargs) -> Any: ...
def enable_x64(enable=True) -> None: ...
class MutableArray: ...
def saved_input_vjp(f, *primals) -> tuple[Any, Callable]: ...
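
As a cautious sketch of one of these (APIs under jax.experimental can change between releases), io_callback calls back into host Python from compiled code:

import jax
import jax.numpy as jnp
from jax.experimental import io_callback

def log_on_host(x):
    # Runs in ordinary Python on the host, outside the compiled computation
    print("host saw:", x)
    return x

@jax.jit
def f(x):
    return io_callback(log_on_host, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

f(jnp.arange(3.0))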

Experimental Features

Core Types

class Array:
    """JAX array type for numerical computing."""
    shape: tuple[int, ...]
    dtype: numpy.dtype
    size: int
    ndim: int
    
    def __array__(self) -> numpy.ndarray: ...
    def __getitem__(self, key) -> Array: ...
    def astype(self, dtype) -> Array: ...
    def reshape(self, *shape) -> Array: ...
    def transpose(self, *axes) -> Array: ...

class Device:
    """Device abstraction for accelerators."""
    platform: str
    device_kind: str
    id: int
    host_id: int
    
class ShapeDtypeStruct:
    """Shape and dtype structure for abstract evaluation."""
    shape: tuple[int, ...]
    dtype: numpy.dtype
    
    def __init__(self, shape, dtype): ...

PRNGKeyArray = Array  # Type alias for PRNG keys
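
A short sketch of abstract evaluation with ShapeDtypeStruct, which computes output shapes and dtypes without allocating or running anything:

import jax
import jax.numpy as jnp

def model(w, x):
    return jnp.tanh(x @ w)

w_spec = jax.ShapeDtypeStruct((128, 64), jnp.float32)
x_spec = jax.ShapeDtypeStruct((32, 128), jnp.float32)

out_spec = jax.eval_shape(model, w_spec, x_spec)
print(out_spec.shape, out_spec.dtype)   # (32, 64) float32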

Configuration and Debugging

# Configuration flags
jax.config.update('jax_enable_x64', True)  # Enable 64-bit precision
jax.config.update('jax_debug_nans', True)  # Debug NaN values
jax.config.update('jax_debug_infs', True)  # Debug Inf values
jax.config.update('jax_platform_name', 'cpu')  # Force platform
jax.config.update('jax_default_device', device)  # Set default device
jax.config.update('jax_compilation_cache_dir', '/path/to/cache')  # Cache directory
jax.config.update('jax_disable_jit', True)  # Disable JIT globally
jax.config.update('jax_log_compiles', True)  # Log compilation events

# Core utilities and debugging
def typeof(x) -> Any: ...
def live_arrays() -> list[Array]: ...
def clear_caches() -> None: ...
def make_jaxpr(fun) -> Callable: ...
def eval_shape(fun, *args, **kwargs) -> Any: ...
def print_environment_info() -> None: ...
def ensure_compile_time_eval() -> None: ...
def pure_callback(callback, result_shape_dtypes, *args, **kwargs) -> Any: ...
def effects_barrier() -> None: ...
def named_call(f, *, name: str) -> Callable: ...
def named_scope(name: str): ...
def disable_jit(disable: bool = True): ...

# Memory and performance utilities
def device_count_per_host() -> int: ...
def host_callback(callback, result_shape, *args, **kwargs) -> Any: ...
def make_mesh(*mesh_axes, axis_names=None) -> Any: ...
def with_sharding_constraint(x, constraint) -> Array: ...

# Advanced debugging
def debug_print(fmt: str, *args) -> None: ...
def debug_callback(callback, *args) -> None: ...
def debug_key_reuse(enable: bool = True) -> None: ...
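
A hedged sketch of two common debugging workflows: printing from inside compiled code with jax.debug.print, and inspecting the jaxpr a function traces to:

import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)   # raise on NaNs instead of silently propagating

@jax.jit
def step(x):
    # jax.debug.print works under jit/vmap, unlike plain print
    jax.debug.print("x = {x}", x=x)
    return jnp.sqrt(x) + 1.0

step(jnp.array(4.0))

# Inspect the traced intermediate representation
print(jax.make_jaxpr(step)(jnp.array(4.0)))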