Differentiate, compile, and transform NumPy code.
npx @tessl/cli install tessl/pypi-jax@0.7.00
# JAX

JAX is a NumPy-compatible library for composable transformations of Python+NumPy programs: differentiate, compile, vectorize, and parallelize NumPy-style code. It combines automatic differentiation (`grad`), just-in-time compilation (`jit`), vectorization (`vmap`), and parallelization (`pmap`), with support for CPUs, GPUs, and TPUs.

## Package Information

- **Package Name**: jax
- **Language**: Python
- **Installation**: `pip install jax[cpu]` (CPU) or `pip install jax[cuda12]` (GPU)

## Core Imports

```python
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, pmap
```

Import specific transformations:

```python
from jax import (
    grad, jit, vmap, pmap, jacfwd, jacrev,
    hessian, value_and_grad, checkpoint
)
```

Import array types and devices:

```python
from jax import Array, Device
import jax.numpy as jnp
import jax.random as jr
import jax.lax as lax
import jax.scipy as jsp
import jax.nn as jnn
import jax.tree as tree
```

## Basic Usage

```python
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap

# NumPy-compatible arrays and operations
x = jnp.array([1.0, 2.0, 3.0, 4.0])
y = jnp.sum(x ** 2)  # JAX arrays work like NumPy

# Automatic differentiation
def loss_fn(params, x, y):
    pred = params[0] * x + params[1]
    return jnp.mean((pred - y) ** 2)

# Compute the gradient of the loss with respect to params (argument 0)
grad_fn = grad(loss_fn)
params = jnp.array([0.5, 0.1])
gradients = grad_fn(params, x, y)

# Just-in-time compilation for performance
@jit
def fast_function(x):
    return jnp.sum(x ** 2) + jnp.sin(x).sum()

result = fast_function(x)

# Vectorization across batch dimension
@vmap
def process_batch(single_input):
    return single_input ** 2 + jnp.sin(single_input)

batch_data = jnp.array([[1, 2], [3, 4], [5, 6]])
batch_result = process_batch(batch_data)

# Random number generation
key = jax.random.key(42)
random_data = jax.random.normal(key, (10, 5))

# Device management
print(f"Available devices: {jax.devices()}")
# devices()[0] is the first available device; it is a CPU when no accelerator is present
array_on_device = jax.device_put(x, jax.devices()[0])
```

## Architecture

JAX's power comes from its composable function transformations that can be applied to pure Python functions:

- **Pure Functions**: JAX transformations require functions to be functionally pure (no side effects)
- **Function Transformations**: grad, jit, vmap, pmap can be arbitrarily composed
- **XLA Compilation**: Just-in-time compilation to optimized accelerator code
- **Array Programming**: NumPy-compatible array operations with immutable semantics
- **Device Model**: Transparent execution across CPU, GPU, and TPU with explicit device management

The composability enables powerful patterns like `jit(grad(loss_fn))` or `vmap(grad(per_example_loss))`.
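
As a rough illustration of the composition pattern named above, the sketch below builds per-example gradients with `jit(vmap(grad(...)))`. The linear model, the `per_example_loss` body, and the sample data are assumptions made up for this example; they are not part of the package.

```python
import jax
import jax.numpy as jnp

def per_example_loss(w, x, y):
    # Squared error of a linear model for ONE example (hypothetical model).
    return (jnp.dot(w, x) - y) ** 2

# grad differentiates with respect to w (argument 0);
# vmap maps over the batch axis of x and y while sharing w (in_axes=None);
# jit compiles the composed function.
per_example_grads = jax.jit(
    jax.vmap(jax.grad(per_example_loss), in_axes=(None, 0, 0))
)

w = jnp.array([1.0, -1.0])
xs = jnp.array([[1.0, 2.0], [3.0, 4.0], [0.0, 1.0]])
ys = jnp.array([0.0, 1.0, 2.0])

grads = per_example_grads(w, xs, ys)  # one gradient row per example, shape (3, 2)
```

Each row of `grads` equals `2 * (w @ x - y) * x` for the corresponding example, which is what a Python loop over `grad(per_example_loss)` would produce one example at a time.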

## Capabilities

### Core Program Transformations

The fundamental JAX transformations that enable automatic differentiation, compilation, vectorization, and parallelization. These transformations are the core of JAX's power and can be arbitrarily composed.

```python { .api }
def jit(fun: Callable, **kwargs) -> Callable: ...
def grad(fun: Callable, argnums: int | Sequence[int] = 0, **kwargs) -> Callable: ...
def vmap(fun: Callable, in_axes=0, out_axes=0, **kwargs) -> Callable: ...
def pmap(fun: Callable, axis_name=None, **kwargs) -> Callable: ...
def value_and_grad(fun: Callable, argnums: int | Sequence[int] = 0, **kwargs) -> Callable: ...
```

[Core Transformations](./core-transformations.md)
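
To make the `argnums` parameter concrete, here is a minimal sketch that uses `value_and_grad` to get the loss value and the gradients for two arguments in one call. The `loss` function and the data are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def loss(w, b, x, y):
    # Mean squared error of a scalar linear model (hypothetical example).
    pred = w * x + b
    return jnp.mean((pred - y) ** 2)

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 4.0, 6.0])

# argnums=(0, 1) requests gradients for both w and b;
# the result is (value, (dloss/dw, dloss/db)).
value, (dw, db) = jax.value_and_grad(loss, argnums=(0, 1))(1.0, 0.0, x, y)
```

With `argnums` left at its default of `0`, only the gradient for `w` would be returned.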

### NumPy Compatibility API

NumPy-compatible array operations including creation, manipulation, mathematical functions, linear algebra, and reductions. JAX arrays are immutable and support most of the NumPy API, with the added benefits of JIT compilation and automatic differentiation.

```python { .api }
# Array creation
def array(object, dtype=None, **kwargs) -> Array: ...
def zeros(shape, dtype=None) -> Array: ...
def ones(shape, dtype=None) -> Array: ...
def arange(start, stop=None, step=None, dtype=None) -> Array: ...

# Mathematical operations
def sum(a, axis=None, **kwargs) -> Array: ...
def mean(a, axis=None, **kwargs) -> Array: ...
def dot(a, b) -> Array: ...
def matmul(x1, x2) -> Array: ...
```

[NumPy Compatibility](./numpy-compatibility.md)
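
Because JAX arrays are immutable, in-place item assignment as in NumPy (`x[1] = 5.0`) is not available. A minimal sketch of the functional alternative, the `.at[...]` indexed-update syntax, is shown below; the array values are arbitrary.

```python
import jax.numpy as jnp

x = jnp.zeros(4)

# Each update returns a NEW array; x itself is never modified.
y = x.at[1].set(5.0)    # [0., 5., 0., 0.]
z = y.at[::2].add(1.0)  # [1., 5., 1., 0.]
```

Inside `jit`, the compiler can often perform such updates in place when the old array is no longer used, so the functional style does not necessarily imply an extra copy.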

### Neural Network Functions

Activation functions, initializers, and neural network utilities commonly used in machine learning. Includes all standard activations like ReLU, sigmoid, softmax, and modern variants like GELU, Swish, and attention mechanisms.

```python { .api }
def relu(x) -> Array: ...
def sigmoid(x) -> Array: ...
def softmax(x, axis=-1) -> Array: ...
def gelu(x, approximate=True) -> Array: ...
def silu(x) -> Array: ...
def one_hot(x, num_classes, **kwargs) -> Array: ...
def dot_product_attention(query, key, value, **kwargs) -> Array: ...
```

[Neural Networks](./neural-networks.md)
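
One common way these pieces fit together is a softmax cross-entropy loss. The sketch below combines `one_hot` and `softmax` from the listing with `jax.nn.log_softmax`, which is not in the listing above but is part of `jax.nn`. The logits and labels are made-up values.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([[2.0, 0.0, -1.0],
                    [0.0, 0.0, 0.0]])
labels = jnp.array([0, 2])

targets = jax.nn.one_hot(labels, num_classes=3)   # shape (2, 3)
log_probs = jax.nn.log_softmax(logits, axis=-1)   # numerically stable log(softmax)

# Cross-entropy per example, then averaged over the batch.
per_example = -jnp.sum(targets * log_probs, axis=-1)
loss = jnp.mean(per_example)

probs = jax.nn.softmax(logits, axis=-1)           # each row sums to 1
```

For the second example the logits are uniform, so its cross-entropy is `log(3)` regardless of the label.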

### Random Number Generation

Functional pseudo-random number generation with explicit key management. JAX uses a functional approach to random numbers that enables reproducibility, parallelization, and vectorization.

```python { .api }
def key(seed: int) -> Array: ...
def split(key: Array, num: int = 2) -> Array: ...
def normal(key: Array, shape=(), dtype=float) -> Array: ...
def uniform(key: Array, shape=(), minval=0.0, maxval=1.0) -> Array: ...
def categorical(key: Array, logits, **kwargs) -> Array: ...
def choice(key: Array, a, **kwargs) -> Array: ...
```

[Random Numbers](./random-numbers.md)
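
The explicit key management described above can be sketched as follows: a key is split before each use, and reusing the same key reproduces the same numbers. The seed and shapes are arbitrary.

```python
import jax

key = jax.random.key(0)

# Split before every draw: keep `key` for later, consume `subkey` now.
key, subkey = jax.random.split(key)
a = jax.random.normal(subkey, (3,))

key, subkey = jax.random.split(key)
b = jax.random.normal(subkey, (3,))

# Sampling is a pure function of the key: the same key gives the same draw.
b_again = jax.random.normal(subkey, (3,))
```

Here `a` and `b` come from different subkeys and are independent draws, while `b_again` is identical to `b`. This is why a key should normally be used only once.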

### Low-Level Operations

Direct XLA operations and primitives for high-performance computing. These provide the building blocks for JAX's higher-level operations and enable custom operations and optimizations.

```python { .api }
def add(x, y) -> Array: ...
def mul(x, y) -> Array: ...
def dot_general(lhs, rhs, dimension_numbers, **kwargs) -> Array: ...
def conv_general_dilated(lhs, rhs, **kwargs) -> Array: ...
def reduce_sum(operand, axes) -> Array: ...
def cond(pred, true_fun, false_fun, *operands) -> Any: ...
def while_loop(cond_fun, body_fun, init_val) -> Any: ...
def scan(f, init, xs, **kwargs) -> tuple[Any, Array]: ...
```

[Low-Level Operations](./low-level-ops.md)
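
The structured control-flow primitives replace Python loops and branches inside compiled code. Below is a small sketch of `lax.scan` (a running sum) and `lax.cond` (a data-dependent branch); the helper names `step` and `safe_sqrt` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax

def step(carry, x):
    # carry: running total so far; x: current element.
    new_carry = carry + x
    return new_carry, new_carry  # (next carry, per-step output)

xs = jnp.array([1.0, 2.0, 3.0, 4.0])
total, running = lax.scan(step, jnp.zeros(()), xs)
# total is the final carry; running stacks the per-step outputs.

@jax.jit
def safe_sqrt(x):
    # Both branches are traced, but only the selected one runs.
    return lax.cond(x >= 0, jnp.sqrt, lambda v: jnp.zeros_like(v), x)

pos = safe_sqrt(jnp.array(4.0))
neg = safe_sqrt(jnp.array(-1.0))
```

Both branches of `cond` must return values of the same shape and dtype, which is why the fallback uses `zeros_like` rather than a Python constant.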

### SciPy Compatibility

SciPy-compatible functions for scientific computing including linear algebra, signal processing, special functions, statistics, and sparse operations. Provides a familiar interface for scientific Python users.

```python { .api }
# Linear algebra (jax.scipy.linalg)
def solve(a, b) -> Array: ...
def eig(a, **kwargs) -> tuple[Array, Array]: ...
def svd(a, **kwargs) -> tuple[Array, Array, Array]: ...

# Special functions (jax.scipy.special)
def logsumexp(a, **kwargs) -> Array: ...
def erf(x) -> Array: ...
def gamma(x) -> Array: ...

# Statistics (jax.scipy.stats); functions live on distribution submodules
def pdf(x, loc=0, scale=1) -> Array: ...  # norm.pdf
def pdf(x, mean, cov) -> Array: ...       # multivariate_normal.pdf
```

[SciPy Compatibility](./scipy-compatibility.md)
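
A brief sketch of three of the listed functions follows. It shows why `logsumexp` is preferred over the naive formula for very negative inputs, and calls `norm.pdf` and `solve` on small made-up values.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

log_w = jnp.array([-1000.0, -1000.0])

# Naive version underflows: exp(-1000) is 0 in floating point, so log(0) = -inf.
naive = jnp.log(jnp.sum(jnp.exp(log_w)))
# Stable version: mathematically -1000 + log(2).
stable = logsumexp(log_w)

# Standard normal density at 0, i.e. 1 / sqrt(2 * pi).
density = norm.pdf(0.0, loc=0.0, scale=1.0)

# Solve a @ x_sol = b for a small diagonal system.
a = jnp.array([[2.0, 0.0], [0.0, 4.0]])
b = jnp.array([2.0, 8.0])
x_sol = solve(a, b)
```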

### Tree Operations

Utilities for working with PyTrees (nested Python structures containing arrays). Essential for handling complex data structures in functional programming patterns and neural network parameters.

```python { .api }
# Names as exposed in jax.tree_util; jax.tree provides the same operations
# as map, reduce, flatten, unflatten, leaves, and structure.
def tree_map(f, tree, *rest) -> Any: ...
def tree_reduce(function, tree, **kwargs) -> Any: ...
def tree_flatten(tree) -> tuple[list, Any]: ...
def tree_unflatten(treedef, leaves) -> Any: ...
def tree_leaves(tree) -> list: ...
def tree_structure(tree) -> Any: ...
```

[Tree Operations](./tree-operations.md)
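
As an illustration of how PyTrees are typically used for model parameters, the sketch below applies a gradient-descent-style update to a nested dict with `jax.tree.map` and counts parameters via `jax.tree.flatten`. The parameter layout, the stand-in gradients, and the step size `0.1` are assumptions for the example.

```python
import jax
import jax.numpy as jnp

params = {
    "dense": {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)},
    "scale": jnp.array(2.0),
}

# Stand-in gradients with the same tree structure as params.
grads = jax.tree.map(jnp.ones_like, params)

# tree.map applies the function leaf by leaf across both trees.
updated = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)

# Flatten to a list of leaves plus a structure description, then rebuild.
leaves, treedef = jax.tree.flatten(params)
n_params = sum(leaf.size for leaf in leaves)  # 6 + 3 + 1
rebuilt = jax.tree.unflatten(treedef, leaves)
```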

### Device and Memory Management

Device placement, memory management, and distributed computing primitives. Enables efficient use of accelerators and scaling across multiple devices.

```python { .api }
def devices() -> list[Device]: ...
def device_put(x, device=None) -> Array: ...
def device_get(x) -> Any: ...
class NamedSharding: ...
def make_mesh(axis_shapes, axis_names, **kwargs) -> Mesh: ...
def shard_map(f, mesh, in_specs, out_specs, **kwargs) -> Callable: ...
```

[Device and Memory Management](./device-memory.md)
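
The following sketch shows explicit placement on a single device and sharding across all visible devices. It builds the mesh with `jax.sharding.Mesh` directly, and the axis name `"data"` is an arbitrary choice. On a machine with one device the sharded array simply lives on that device.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
n = len(devices)

# Length is a multiple of the device count so it splits evenly.
x = jnp.arange(2 * n, dtype=jnp.float32)

# Explicit placement on one device, then transfer back to host memory.
x_dev = jax.device_put(x, devices[0])
host = jax.device_get(x_dev)  # numpy.ndarray

# 1-D mesh over all devices; shard the array along the "data" axis.
mesh = Mesh(np.array(devices), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))
x_sharded = jax.device_put(x, sharding)
```

Computations on `x_sharded` under `jit` generally keep the data distributed, with the compiler inserting any communication that is needed.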

### Experimental Features

Cutting-edge and experimental JAX features including new APIs, performance optimizations, and research capabilities. These features may change in future versions.

```python { .api }
def io_callback(callback, result_shape_dtypes, *args, **kwargs) -> Any: ...
def enable_x64(enable=True) -> None: ...
class MutableArray: ...
def saved_input_vjp(f, *primals) -> tuple[Any, Callable]: ...
```

[Experimental Features](./experimental.md)

## Core Types

```python { .api }
class Array:
    """JAX array type for numerical computing."""
    shape: tuple[int, ...]
    dtype: numpy.dtype
    size: int
    ndim: int

    def __array__(self) -> numpy.ndarray: ...
    def __getitem__(self, key) -> Array: ...
    def astype(self, dtype) -> Array: ...
    def reshape(self, *shape) -> Array: ...
    def transpose(self, *axes) -> Array: ...

class Device:
    """Device abstraction for accelerators."""
    platform: str
    device_kind: str
    id: int
    host_id: int

class ShapeDtypeStruct:
    """Shape and dtype structure for abstract evaluation."""
    shape: tuple[int, ...]
    dtype: numpy.dtype

    def __init__(self, shape, dtype): ...

PRNGKeyArray = Array  # Type alias for PRNG keys
```
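
To show what "abstract evaluation" means for `ShapeDtypeStruct`, here is a minimal sketch that uses `jax.eval_shape` to compute an output shape and dtype without allocating arrays or running any computation. The `model` function and the shapes are hypothetical.

```python
import jax
import jax.numpy as jnp

def model(w, x):
    # Hypothetical single-layer model.
    return jnp.tanh(x @ w)

# Describe the inputs by shape and dtype only; no data is allocated.
w_spec = jax.ShapeDtypeStruct((128, 10), jnp.float32)
x_spec = jax.ShapeDtypeStruct((32, 128), jnp.float32)

out = jax.eval_shape(model, w_spec, x_spec)
print(out.shape, out.dtype)
```

The result `out` is itself a `ShapeDtypeStruct`, so it can be fed into further `eval_shape` calls when sizing a larger pipeline.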

## Configuration and Debugging

```python { .api }
# Configuration flags
jax.config.update('jax_enable_x64', True)  # Enable 64-bit precision
jax.config.update('jax_debug_nans', True)  # Debug NaN values
jax.config.update('jax_debug_infs', True)  # Debug Inf values
jax.config.update('jax_platform_name', 'cpu')  # Force platform
jax.config.update('jax_default_device', device)  # Set default device
jax.config.update('jax_compilation_cache_dir', '/path/to/cache')  # Cache directory
jax.config.update('jax_disable_jit', True)  # Disable JIT globally
jax.config.update('jax_log_compiles', True)  # Log compilation events

# Core utilities and debugging
def typeof(x) -> Any: ...
def live_arrays() -> list[Array]: ...
def clear_caches() -> None: ...
def make_jaxpr(fun) -> Callable: ...
def eval_shape(fun, *args, **kwargs) -> Any: ...
def print_environment_info() -> None: ...
def ensure_compile_time_eval() -> None: ...
def pure_callback(callback, result_shape_dtypes, *args, **kwargs) -> Any: ...
def effects_barrier() -> None: ...
def named_call(f, *, name: str) -> Callable: ...
def named_scope(name: str): ...
def disable_jit(disable: bool = True): ...

# Memory and performance utilities
def device_count_per_host() -> int: ...
def host_callback(callback, result_shape, *args, **kwargs) -> Any: ...
def make_mesh(axis_shapes, axis_names, **kwargs) -> Any: ...
def with_sharding_constraint(x, constraint) -> Array: ...

# Advanced debugging (the jax.debug module exposes print and callback)
def debug_print(fmt: str, *args) -> None: ...
def debug_callback(callback, *args) -> None: ...
def debug_key_reuse(enable: bool = True) -> None: ...
```
```
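
A short sketch of a typical debugging workflow follows: inspect the traced program with `make_jaxpr`, print traced values with `jax.debug.print`, and run eagerly under `disable_jit`. It assumes that `debug_print` in the listing above corresponds to `jax.debug.print`; the function `f` is made up.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # A plain print() would only fire at trace time; jax.debug.print
    # reports the runtime value on every call.
    jax.debug.print("x = {}", x)
    return jnp.sin(x) * 2.0

# Show the intermediate representation JAX traces for f.
jaxpr = jax.make_jaxpr(f)(1.0)
print(jaxpr)

# Temporarily run op by op, so ordinary Python debugging tools work.
with jax.disable_jit():
    y = f(1.0)
```

Using the `disable_jit` context manager keeps the change local, unlike `jax.config.update('jax_disable_jit', True)`, which affects the whole process.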