tessl/pypi-jax

Differentiate, compile, and transform NumPy code.

- Workspace: tessl
- Visibility: Public
- Describes: pypipkg:pypi/jax@0.7.x

To install, run `npx @tessl/cli install tessl/pypi-jax@0.7.0`

# JAX

JAX is a NumPy-compatible library for composable transformations of Python+NumPy programs: it can differentiate, compile, and transform NumPy code. Its core transformations are automatic differentiation (`grad`), just-in-time compilation (`jit`), vectorization (`vmap`), and parallelization (`pmap`), and the same programs run on CPUs, GPUs, and TPUs.

## Package Information

- **Package Name**: jax
- **Language**: Python
- **Installation**: `pip install jax[cpu]` (CPU) or `pip install jax[cuda12]` (GPU)

## Core Imports

```python
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, pmap
```

Import specific transformations:

```python
from jax import (
    grad, jit, vmap, pmap, jacfwd, jacrev,
    hessian, value_and_grad, checkpoint
)
```

Import array types, devices, and submodules:

```python
from jax import Array, Device
import jax.numpy as jnp
import jax.random as jr
import jax.lax as lax
import jax.scipy as jsp
import jax.nn as jnn
import jax.tree as tree
```

## Basic Usage

```python
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap

# NumPy-compatible arrays and operations
x = jnp.array([1.0, 2.0, 3.0, 4.0])
sum_of_squares = jnp.sum(x ** 2)  # JAX arrays work like NumPy arrays

# Automatic differentiation
def loss_fn(params, x, y):
    pred = params[0] * x + params[1]
    return jnp.mean((pred - y) ** 2)

# Gradient of the loss with respect to its first argument (params)
grad_fn = grad(loss_fn)
params = jnp.array([0.5, 0.1])
y = 2.0 * x + 1.0  # regression targets, one per input
gradients = grad_fn(params, x, y)  # same shape as params: (2,)

# Just-in-time compilation for performance
@jit
def fast_function(x):
    return jnp.sum(x ** 2) + jnp.sin(x).sum()

result = fast_function(x)

# Vectorization across the leading (batch) dimension
@vmap
def process_batch(single_input):
    return single_input ** 2 + jnp.sin(single_input)

batch_data = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
batch_result = process_batch(batch_data)  # shape (3, 2)

# Random number generation with an explicit key
key = jax.random.key(42)
random_data = jax.random.normal(key, (10, 5))

# Device management
print(f"Available devices: {jax.devices()}")
x_on_device = jax.device_put(x, jax.devices()[0])  # the first device, which may be a CPU
```

## Architecture

JAX's power comes from composable function transformations applied to pure Python functions:

- **Pure Functions**: Transformations assume functionally pure functions (no side effects); Python side effects such as `print` run only while a function is being traced.
- **Function Transformations**: `grad`, `jit`, `vmap`, and `pmap` can be composed arbitrarily.
- **XLA Compilation**: `jit` compiles functions just in time, via XLA, to optimized code for the target accelerator.
- **Array Programming**: NumPy-compatible array operations with immutable semantics.
- **Device Model**: The same code runs on CPU, GPU, and TPU; device placement can also be controlled explicitly.

The composability enables patterns such as `jit(grad(loss_fn))` or `vmap(grad(per_example_loss))`.
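
As a minimal sketch of the `vmap(grad(per_example_loss))` pattern, the example below computes one gradient per training example for a linear model. `per_example_loss` and the data are hypothetical illustrations, not part of JAX; `in_axes=(None, 0, 0)` shares `params` across examples while mapping over the inputs and targets.

```python
import jax
import jax.numpy as jnp

def per_example_loss(params, x, y):
    # Hypothetical loss: squared error of a linear model on one (x, y) pair
    pred = params[0] * x + params[1]
    return (pred - y) ** 2

# grad differentiates w.r.t. params; vmap maps over axis 0 of x and y while
# broadcasting params (in_axes=None); jit compiles the composed function.
per_example_grads = jax.jit(
    jax.vmap(jax.grad(per_example_loss), in_axes=(None, 0, 0))
)

params = jnp.array([0.5, 0.1])
xs = jnp.array([1.0, 2.0, 3.0])
ys = jnp.array([3.0, 5.0, 7.0])
grads = per_example_grads(params, xs, ys)
print(grads.shape)  # (3, 2): one gradient per example
```

Averaging `grads` over axis 0 gives the ordinary batch gradient, while keeping them separate enables techniques such as per-example gradient clipping.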

## Capabilities

### Core Program Transformations

The fundamental JAX transformations provide automatic differentiation, compilation, vectorization, and parallelization. They are the core of JAX and can be composed arbitrarily.

```python { .api }
def jit(fun: Callable, **kwargs) -> Callable: ...
def grad(fun: Callable, argnums: int | Sequence[int] = 0, **kwargs) -> Callable: ...
def vmap(fun: Callable, in_axes=0, out_axes=0, **kwargs) -> Callable: ...
def pmap(fun: Callable, axis_name=None, **kwargs) -> Callable: ...
def value_and_grad(fun: Callable, argnums: int | Sequence[int] = 0, **kwargs) -> Callable: ...
```

[Core Transformations](./core-transformations.md)
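
A short sketch of `value_and_grad` with a tuple `argnums`, composed with `jit`, plus `vmap` with per-argument `in_axes`. The function `f` and its inputs are arbitrary illustrations chosen so the results are easy to check by symmetry.

```python
import jax
import jax.numpy as jnp

def f(w, b, x):
    # Illustrative scalar-valued function of two parameters and an input x
    return jnp.sum(jnp.tanh(w * x + b))

# value_and_grad returns f(w, b, x) together with its gradient;
# argnums=(0, 1) differentiates with respect to both w and b,
# and jit compiles the combined computation.
f_and_grads = jax.jit(jax.value_and_grad(f, argnums=(0, 1)))

x = jnp.linspace(-1.0, 1.0, 5)
value, (dw, db) = f_and_grads(1.0, 0.0, x)

# vmap evaluates f for a batch of w values while sharing b and x
batched_f = jax.vmap(f, in_axes=(0, None, None))
values = batched_f(jnp.array([0.5, 1.0, 2.0]), 0.0, x)  # shape (3,)
```

Because `x` is symmetric around zero and `tanh` is odd, `value` and `dw` come out as (approximately) zero, while `db` is positive.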

### NumPy Compatibility API

NumPy-compatible array operations, including creation, manipulation, mathematical functions, linear algebra, and reductions. `jax.numpy` implements most of the NumPy API, with the added benefits of JIT compilation and automatic differentiation. Unlike NumPy arrays, JAX arrays are immutable: in-place updates are written functionally as `x.at[idx].set(value)`, which returns a new array.

```python { .api }
# Array creation
def array(object, dtype=None, **kwargs) -> Array: ...
def zeros(shape, dtype=None) -> Array: ...
def ones(shape, dtype=None) -> Array: ...
def arange(start, stop=None, step=None, dtype=None) -> Array: ...

# Mathematical operations
def sum(a, axis=None, **kwargs) -> Array: ...
def mean(a, axis=None, **kwargs) -> Array: ...
def dot(a, b) -> Array: ...
def matmul(x1, x2) -> Array: ...
```

[NumPy Compatibility](./numpy-compatibility.md)
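
A minimal sketch of the immutability rule: item assignment fails, and the `.at[...]` syntax produces an updated copy instead. The array values are arbitrary examples.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

# Arrays are immutable: item assignment raises a TypeError
try:
    x[0, 0] = 10.0
    immutable = False
except TypeError:
    immutable = True

# Functional update: returns a new array and leaves x unchanged
y = x.at[0, 0].set(10.0)

row_means = jnp.mean(y, axis=1)
prod = jnp.matmul(y, y.T)  # (2, 3) @ (3, 2) -> (2, 2)

# JAX arrays convert back to NumPy arrays via np.asarray
host_copy = np.asarray(prod)
```

Under `jit`, XLA can often perform such `.at[...]` updates in place, so the functional style does not necessarily imply a copy.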

### Neural Network Functions

Activation functions, initializers, and neural network utilities commonly used in machine learning (`jax.nn`). Includes standard activations such as ReLU, sigmoid, and softmax, newer variants such as GELU and SiLU (also known as Swish), and utilities such as `one_hot` and `dot_product_attention`.

```python { .api }
def relu(x) -> Array: ...
def sigmoid(x) -> Array: ...
def softmax(x, axis=-1) -> Array: ...
def gelu(x, approximate=True) -> Array: ...
def silu(x) -> Array: ...
def one_hot(x, num_classes, **kwargs) -> Array: ...
def dot_product_attention(query, key, value, **kwargs) -> Array: ...
```

[Neural Networks](./neural-networks.md)
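
A sketch of these functions working together in one dense layer with a cross-entropy loss. The layer sizes, labels, and the choice of `glorot_normal` from `jax.nn.initializers` are illustrative assumptions, not a prescribed recipe.

```python
import jax
import jax.numpy as jnp

key_w, key_x = jax.random.split(jax.random.key(0))

# jax.nn.initializers.glorot_normal() returns an init function (key, shape) -> Array
w = jax.nn.initializers.glorot_normal()(key_w, (4, 3))
x = jax.random.normal(key_x, (2, 4))  # batch of 2 inputs with 4 features

logits = jax.nn.gelu(x @ w)  # (2, 3)
probs = jax.nn.softmax(logits, axis=-1)

labels = jnp.array([0, 2])
targets = jax.nn.one_hot(labels, 3)  # (2, 3)

# Cross-entropy via log_softmax, which is more stable than log(softmax(...))
loss = -jnp.mean(jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1))
```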

### Random Number Generation

Functional pseudo-random number generation with explicit key management. Random functions take a key and never update it, so the same key always produces the same numbers; fresh randomness comes from deriving new keys with `split`. This functional approach enables reproducibility, parallelization, and vectorization.

```python { .api }
def key(seed: int) -> Array: ...
def split(key: Array, num: int = 2) -> Array: ...
def normal(key: Array, shape=(), dtype=float) -> Array: ...
def uniform(key: Array, shape=(), minval=0.0, maxval=1.0) -> Array: ...
def categorical(key: Array, logits, **kwargs) -> Array: ...
def choice(key: Array, a, **kwargs) -> Array: ...
```

[Random Numbers](./random-numbers.md)

### Low-Level Operations

Direct XLA operations and primitives (`jax.lax`) for high-performance computing. They are the building blocks of JAX's higher-level operations, such as `jax.numpy`, and allow custom operations and optimizations. The control-flow primitives `cond`, `while_loop`, and `scan` let data-dependent branches and loops be traced and compiled under `jit`.

```python { .api }
def add(x, y) -> Array: ...
def mul(x, y) -> Array: ...
def dot_general(lhs, rhs, dimension_numbers, **kwargs) -> Array: ...
def conv_general_dilated(lhs, rhs, **kwargs) -> Array: ...
def reduce_sum(operand, axes) -> Array: ...
def cond(pred, true_fun, false_fun, *operands) -> Any: ...
def while_loop(cond_fun, body_fun, init_val) -> Any: ...
def scan(f, init, xs, **kwargs) -> tuple[Any, Array]: ...
```

[Low-Level Operations](./low-level-ops.md)
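
A sketch of the control-flow primitives and `dot_general`. The functions `step` and `safe_log` are hypothetical helpers for illustration; the `dimension_numbers` tuple has the form `((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))`.

```python
import jax
import jax.numpy as jnp
from jax import lax

# scan: thread a running total (the carry) through xs, emitting it at each step
def step(carry, x):
    total = carry + x
    return total, total  # (new carry, per-step output)

final, cumulative = lax.scan(step, jnp.float32(0.0), jnp.array([1.0, 2.0, 3.0, 4.0]))

# while_loop: keep doubling while the value is below 100
doubled = lax.while_loop(lambda v: v < 100.0, lambda v: v * 2.0, jnp.float32(1.0))

# cond: pick a branch from a traced predicate, so it works under jit
@jax.jit
def safe_log(x):
    return lax.cond(x > 0, jnp.log, jnp.zeros_like, x)

# dot_general as an ordinary matrix product:
# contract lhs axis 1 with rhs axis 0, no batch dimensions
a = jnp.ones((2, 3))
b = jnp.ones((3, 4))
c = lax.dot_general(a, b, dimension_numbers=(((1,), (0,)), ((), ())))
```

Here `final` is 10, `cumulative` holds the running sums 1, 3, 6, 10, and `doubled` stops at 128, the first power of two that is not below 100.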

### SciPy Compatibility

SciPy-compatible functions for scientific computing, including linear algebra, signal processing, special functions, statistics, and sparse operations. They provide a familiar interface for scientific Python users.

```python { .api }
# Linear algebra (jax.scipy.linalg)
def solve(a, b) -> Array: ...
def eig(a, **kwargs) -> tuple[Array, Array]: ...
def svd(a, **kwargs) -> tuple[Array, Array, Array]: ...

# Special functions (jax.scipy.special)
def logsumexp(a, **kwargs) -> Array: ...
def erf(x) -> Array: ...
def gamma(x) -> Array: ...

# Statistics (jax.scipy.stats); norm and multivariate_normal are submodules:
#   norm.pdf(x, loc=0, scale=1) -> Array
#   multivariate_normal.pdf(x, mean, cov) -> Array
```

[SciPy Compatibility](./scipy-compatibility.md)
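
A short sketch using one function from each of the `linalg`, `special`, and `stats` groups; the matrix and inputs are arbitrary values chosen so the results can be checked by hand.

```python
import jax.numpy as jnp
from jax.scipy import linalg, special, stats

# Solve the linear system A x = b
a = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([9.0, 8.0])
x = linalg.solve(a, b)  # [2., 3.]

# logsumexp stays finite where a naive log(sum(exp(...))) would overflow
big = jnp.array([1000.0, 1000.0])
lse = special.logsumexp(big)  # 1000 + log(2)

# Standard normal density at 0, i.e. 1 / sqrt(2 * pi)
density = stats.norm.pdf(0.0)
```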

### Tree Operations

Utilities for working with PyTrees: nested Python containers (lists, tuples, dicts, and registered custom types) whose leaves are typically arrays. They are essential for handling structured data, such as neural network parameters, in a functional style.

```python { .api }
# jax.tree_util; also available as jax.tree.map, jax.tree.reduce, jax.tree.flatten, etc.
def tree_map(f, tree, *rest) -> Any: ...
def tree_reduce(function, tree, **kwargs) -> Any: ...
def tree_flatten(tree) -> tuple[list, Any]: ...
def tree_unflatten(treedef, leaves) -> Any: ...
def tree_leaves(tree) -> list: ...
def tree_structure(tree) -> Any: ...
```

[Tree Operations](./tree-operations.md)
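
A sketch of common PyTree operations on a hypothetical parameter dictionary, including an SGD-style update that maps over two trees with the same structure. The parameter names and the 0.1 learning rate are illustrative only.

```python
import jax
import jax.numpy as jnp

# Hypothetical model parameters stored as a nested dict (a PyTree)
params = {
    "dense": {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)},
    "scale": jnp.array(2.0),
}

# Apply a function to every leaf, preserving the nested structure
doubled = jax.tree.map(lambda p: p * 2, params)

# Flatten into a list of leaves plus a treedef, then rebuild the tree
leaves, treedef = jax.tree.flatten(params)
rebuilt = jax.tree.unflatten(treedef, leaves)

# Count parameters across all leaves: 6 + 3 + 1
n_params = sum(leaf.size for leaf in jax.tree.leaves(params))

# SGD-style update over two trees with identical structure
grads = jax.tree.map(jnp.ones_like, params)
updated = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
```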

### Device and Memory Management

Device placement, memory management, and distributed-computing primitives for using accelerators efficiently and scaling across multiple devices.

```python { .api }
def devices() -> list[Device]: ...
def device_put(x, device=None) -> Array: ...
def device_get(x) -> Any: ...
class NamedSharding: ...
def make_mesh(axis_shapes, axis_names, *, devices=None, **kwargs) -> Mesh: ...
def shard_map(f, mesh, in_specs, out_specs, **kwargs) -> Callable: ...
```

[Device and Memory Management](./device-memory.md)

### Experimental Features

Experimental JAX features (`jax.experimental`), including new APIs, performance optimizations, and research capabilities. These APIs may change or be removed in future versions.

```python { .api }
def io_callback(callback, result_shape_dtypes, *args, **kwargs) -> Any: ...
def enable_x64(enable=True): ...  # used as a context manager: with enable_x64(): ...
class MutableArray: ...
def saved_input_vjp(f, *primals) -> tuple[Any, Callable]: ...
```

[Experimental Features](./experimental.md)

## Core Types

```python { .api }
class Array:
    """JAX array type for numerical computing."""
    shape: tuple[int, ...]
    dtype: numpy.dtype
    size: int
    ndim: int

    def __array__(self) -> numpy.ndarray: ...
    def __getitem__(self, key) -> Array: ...
    def astype(self, dtype) -> Array: ...
    def reshape(self, *shape) -> Array: ...
    def transpose(self, *axes) -> Array: ...

class Device:
    """Device abstraction for accelerators."""
    platform: str
    device_kind: str
    id: int
    host_id: int

class ShapeDtypeStruct:
    """Shape and dtype structure for abstract evaluation."""
    shape: tuple[int, ...]
    dtype: numpy.dtype

    def __init__(self, shape, dtype): ...

PRNGKeyArray = Array  # Type alias for PRNG keys
```
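
A short sketch of the `Array` attributes and of `ShapeDtypeStruct` as input to abstract evaluation with `jax.eval_shape`; the lambda being evaluated is an arbitrary illustration.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((2, 3), dtype=jnp.float32)
print(x.shape, x.dtype, x.size, x.ndim)  # (2, 3) float32 6 2
print(isinstance(x, jax.Array))  # True

# ShapeDtypeStruct describes an array without any data; jax.eval_shape
# propagates shapes and dtypes through a function without running it
spec = jax.ShapeDtypeStruct((2, 3), jnp.float32)
out = jax.eval_shape(lambda a: a.T @ a, spec)
print(out.shape, out.dtype)  # (3, 3) float32
```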

## Configuration and Debugging

```python { .api }
# Configuration flags
jax.config.update('jax_enable_x64', True)  # Enable 64-bit precision
jax.config.update('jax_debug_nans', True)  # Debug NaN values
jax.config.update('jax_debug_infs', True)  # Debug Inf values
jax.config.update('jax_platform_name', 'cpu')  # Force platform
jax.config.update('jax_default_device', jax.devices('cpu')[0])  # Set default device
jax.config.update('jax_compilation_cache_dir', '/path/to/cache')  # Cache directory
jax.config.update('jax_disable_jit', True)  # Disable JIT globally
jax.config.update('jax_log_compiles', True)  # Log compilation events

# Core utilities and debugging
def typeof(x) -> Any: ...
def live_arrays() -> list[Array]: ...
def clear_caches() -> None: ...
def make_jaxpr(fun) -> Callable: ...
def eval_shape(fun, *args, **kwargs) -> Any: ...
def print_environment_info() -> None: ...
def ensure_compile_time_eval(): ...  # context manager
def pure_callback(callback, result_shape_dtypes, *args, **kwargs) -> Any: ...
def effects_barrier() -> None: ...
def named_call(f, *, name: str) -> Callable: ...
def named_scope(name: str): ...  # context manager
def disable_jit(disable: bool = True): ...  # context manager

# Memory and performance utilities
def device_count_per_host() -> int: ...
def host_callback(callback, result_shape, *args, **kwargs) -> Any: ...
def make_mesh(axis_shapes, axis_names, *, devices=None, **kwargs) -> Any: ...
def with_sharding_constraint(x, constraint) -> Array: ...

# Advanced debugging
def debug_print(fmt: str, *args) -> None: ...  # exposed as jax.debug.print
def debug_callback(callback, *args) -> None: ...  # exposed as jax.debug.callback
def debug_key_reuse(enable: bool = True) -> None: ...
```
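
A sketch of three everyday debugging tools: `jax.debug.print` for runtime values inside `jit`, `jax.make_jaxpr` for inspecting the traced program, and `jax.disable_jit` as a context manager for eager execution. The function `f` is an arbitrary illustration.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    y = jnp.sin(x) * 2.0
    # Inside jit, print() would only show an abstract tracer (at trace time);
    # jax.debug.print prints the runtime value on every call.
    jax.debug.print("y = {y}", y=y)
    return y.sum()

total = f(jnp.arange(3.0))

# make_jaxpr exposes the traced intermediate representation
jaxpr = jax.make_jaxpr(lambda x: jnp.sin(x) * 2.0)(1.0)
print(jaxpr)

# disable_jit runs f eagerly, e.g. for stepping through it with pdb
with jax.disable_jit():
    eager_total = f(jnp.arange(3.0))
```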