# Device and Memory Management

JAX provides device management and distributed-computing APIs for CPUs, GPUs, and TPUs. They cover device discovery, data placement, memory inspection, sharding for multi-device computation, and construction of distributed arrays.

## Core Imports

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax import devices, device_put, make_mesh
from jax.sharding import NamedSharding, PartitionSpec as P
```

## Capabilities

### Device Discovery and Information

Query the available devices and their properties to decide where computations run and how to manage resources.

```python { .api }
def devices(backend=None) -> list[Device]:
    """
    Get the list of all devices for a backend, across all processes.

    Args:
        backend: Optional backend name ('cpu', 'gpu', 'tpu'). None selects
            the default backend.

    Returns:
        List of Device objects, including devices attached to other
        processes in a multi-process setup.

    Raises:
        RuntimeError: If the requested backend is unknown or unavailable,
            e.g. jax.devices('gpu') on a machine without a GPU. No empty
            list is returned, so guard such calls with try/except.
    """

def local_devices(process_index=None, backend=None) -> list[Device]:
    """
    Return the devices attached to one process.

    Args:
        process_index: Index of the process whose devices are returned;
            None means the calling process.
        backend: Optional backend name.

    Returns:
        List of Device objects local to that process.
    """
    ...

def device_count(backend=None) -> int:
    """
    Return how many devices exist in total, summed over every process.

    Args:
        backend: Optional backend name.

    Returns:
        Total number of devices.
    """
    ...

def local_device_count(backend=None) -> int:
    """
    Return how many devices the calling process can address.

    Args:
        backend: Optional backend name.

    Returns:
        Number of devices local to this process.
    """
    ...

def host_count(backend=None) -> int:
    """
    Deprecated alias of process_count(); prefer process_count().

    Despite the name, this counts JAX processes, not physical hosts
    (a single host may run several processes). Recent JAX releases may
    no longer provide it.

    Args:
        backend: Optional backend name.

    Returns:
        Number of processes, as returned by process_count().
    """

def host_id(backend=None) -> int:
    """
    Deprecated alias of process_index(); prefer process_index().

    Despite the name, this identifies the JAX process, not the physical
    host (a single host may run several processes). Recent JAX releases
    may no longer provide it.

    Args:
        backend: Optional backend name.

    Returns:
        Index of the current process, as returned by process_index().
    """

def host_ids(backend=None) -> list[int]:
    """
    Deprecated; use list(range(process_count())) instead.

    Despite the name, these are process indices, not host identifiers.
    Recent JAX releases may no longer provide it.

    Args:
        backend: Optional backend name.

    Returns:
        List of all process indices, i.e. list(range(process_count())).
    """

def process_count(backend=None) -> int:
    """
    Return the number of JAX processes taking part in the computation.

    Args:
        backend: Optional backend name.

    Returns:
        Number of processes.
    """
    ...

def process_index(backend=None) -> int:
    """
    Return the index of the process this code is running in.

    Args:
        backend: Optional backend name.

    Returns:
        Index of the calling process.
    """
    ...

def process_indices(backend=None) -> list[int]:
    """
    Return the index of every process taking part in the computation.

    Args:
        backend: Optional backend name.

    Returns:
        List containing each process index.
    """
    ...

def default_backend() -> str:
    """
    Name the backend JAX uses when none is specified.

    Returns:
        Name string of the default backend.
    """
    ...
```

### Device Placement and Data Movement

Control where data lives and move it between devices and host memory.

```python { .api }
def device_put(x, device=None, src=None) -> Array:
    """
    Transfer x to a device, or lay it out according to a sharding.

    Args:
        x: Array, array-like, or pytree of them to transfer.
        device: Target placement. Either a single Device, or a Sharding
            (e.g. NamedSharding) that distributes x across several
            devices. None uses the default device.
        src: Optional source of the transfer; pass it by keyword.

    Returns:
        Array (or pytree of Arrays) with the requested placement.
    """

def device_put_sharded(
    sharded_values: list,
    devices: list[Device],
    indices=None,
) -> Array:
    """
    Stack per-device values into one array, one slice per device.

    This is a legacy pmap-style API. The result has a new leading axis of
    length len(devices), and slice i lives on devices[i]. For new code,
    prefer jax.device_put(x, sharding) or
    jax.make_array_from_single_device_arrays.

    Args:
        sharded_values: Sequence of arrays (or pytrees) with identical
            shapes and dtypes, one per device.
        devices: Target devices, same length as sharded_values.
        indices: NOTE(review): jax.device_put_sharded takes only the shards
            and the devices; confirm before relying on this parameter.

    Returns:
        Array of shape (len(devices), *sharded_values[0].shape).
    """

def device_put_replicated(x, devices: list[Device]) -> Array:
    """
    Stack len(devices) copies of x, one per device.

    This is a legacy pmap-style API. The result has a new leading axis of
    length len(devices), and slice i (a full copy of x) lives on
    devices[i]. It is therefore shaped for jax.pmap; it is not a
    "replicated" jax.Array with x's own shape. For that, use
    jax.device_put(x, NamedSharding(mesh, P())).

    Args:
        x: Array (or pytree of arrays) to copy.
        devices: Target devices.

    Returns:
        Array of shape (len(devices), *x.shape).
    """

def device_get(x) -> Any:
    """
    Fetch x from device memory into host memory.

    Args:
        x: Array to bring back to the host.

    Returns:
        The corresponding NumPy array, resident in host memory.
    """
    ...

def copy_to_host_async(x) -> None:
    """
    Start copying x to host memory without waiting for the copy to finish.

    In JAX this is the jax.Array method x.copy_to_host_async(). It is not
    exposed as a top-level jax function. It returns None, not a future.
    Its purpose is to get the transfer going early, so a later
    jax.device_get(x) or np.asarray(x) has less (or nothing) left to wait for.

    Args:
        x: Array to copy.
    """

def block_until_ready(x) -> Array:
    """
    Wait until the computation producing x has finished.

    Args:
        x: Array to wait on.

    Returns:
        x itself, now guaranteed to be ready.
    """
    ...
```

Usage examples:

```python
# Check available devices
all_devices = jax.devices()
print(f"Available devices: {all_devices}")
print(f"Device count: {jax.device_count()}")

# Move data to a specific device. jax.devices('gpu') raises RuntimeError
# (it does not return an empty list) when no GPU backend is available.
data = jnp.array([1, 2, 3, 4])  # placed on the default device
try:
    gpu_devices = jax.devices('gpu')
except RuntimeError:
    gpu_devices = []

if gpu_devices:
    gpu_data = jax.device_put(data, gpu_devices[0])
    # .devices() returns the set of devices holding the array
    # (in current JAX, .device is a property, not a method).
    print(f"Data is on: {gpu_data.devices()}")

    # Move back to host
    host_data = jax.device_get(gpu_data)  # Returns a NumPy array

# Explicit device placement in computations
with jax.default_device(jax.devices('cpu')[0]):
    cpu_result = jnp.sum(jnp.array([1, 2, 3]))
```

### Sharding and Distributed Arrays

Describe how arrays are partitioned across multiple devices for parallel computation.

```python { .api }
class NamedSharding:
    """
    Sharding described by named mesh axes.

    The spec maps array dimensions to logical mesh-axis names, which
    determines how the array is split over the mesh's devices.
    """

    def __init__(self, mesh, spec):
        """
        Build the sharding.

        Args:
            mesh: Device mesh whose axes carry names.
            spec: PartitionSpec describing how each dimension is split.
        """
        self.spec = spec
        self.mesh = mesh

class PartitionSpec:
    """
    Specification for how to partition array dimensions across mesh axes.

    Build one with P(...), giving one entry per leading array dimension:
      - a mesh axis name, e.g. 'data': split that dimension over the axis;
      - a tuple of axis names: split it over several axes together;
      - None: do not split that dimension (replicate it).
    Dimensions beyond the given entries are not split, so P() means
    "fully replicated".

    Example:
        P('data', None)   # split dim 0 over 'data', keep dim 1 whole
    """

    def __init__(self, *partitions):
        # Keep the per-dimension entries so the spec can be inspected;
        # P() (no entries) stays valid.
        self.partitions = partitions

    def __repr__(self):
        return f"PartitionSpec{self.partitions!r}"

# P is the very same class as PartitionSpec, just shorter to type.
P = PartitionSpec

def make_mesh(mesh_shape, axis_names) -> Mesh:
    """
    Create a device mesh for distributed computation.

    Args:
        mesh_shape: Mesh shape as a tuple of ints, one size per axis.
        axis_names: Axis names as a tuple of strings, same length as
            mesh_shape.

    Returns:
        Mesh laying out the devices in the requested shape.

    Note:
        By default every device from jax.devices() is used, so the product
        of mesh_shape must equal jax.device_count(). For example, (2, 2)
        fails on an 8-device machine. To build a mesh over a subset of
        devices, use jax.sharding.Mesh(np.array(devs).reshape(shape),
        axis_names).
    """

class Mesh:
    """Named grid of devices used for distributed computation."""

    devices: Array  # the devices, arranged in the mesh's shape
    axis_names: tuple[str, ...]  # one name per mesh axis

    @property
    def size(self) -> int:
        """Number of devices the mesh spans."""

    @property
    def shape(self) -> dict[str, int]:
        """Map each axis name to that axis's size."""

def make_array_from_single_device_arrays(
    arrays: list[Array],
    sharding: Sharding,
) -> Array:
    """
    Create a distributed array from per-device arrays.

    NOTE(review): the real jax.make_array_from_single_device_arrays takes
    (shape, sharding, arrays). The global shape comes first and the
    per-device arrays last. This listing omits `shape`; follow the JAX
    signature when calling it.

    Args:
        arrays: Single-device arrays, one per addressable device in the
            sharding, each holding that device's shard.
        sharding: Sharding describing how the global array is split.

    Returns:
        Global array with the given sharding.
    """

def make_array_from_callback(
    shape: tuple[int, ...],
    sharding: Sharding,
    data_callback: Callable,
) -> Array:
    """
    Create a distributed array by asking a callback for each shard's data.

    Args:
        shape: Global array shape.
        sharding: Sharding that decides which devices hold which slices.
        data_callback: Called for the addressable shards with each shard's
            index into the global array (a tuple of slices). It must
            return the data for that slice, e.g. global_np_array[index].

    Returns:
        Global array assembled from the callback's outputs.
    """

def make_array_from_process_local_data(
    sharding: Sharding,
    local_data: Array,
    global_shape=None,
) -> Array:
    """
    Create a distributed array from this process's slice of the data.

    Each process calls this with the part of the global array that its own
    devices hold under `sharding`.

    Args:
        sharding: Sharding of the global array.
        local_data: Data held by the current process, e.g. its slice of
            the global batch.
        global_shape: Optional global shape. When None, JAX infers it from
            local_data's shape and the sharding.

    Returns:
        Global array assembled from every process's local data.
    """
```

### Sharded Computation

Run computations on sharded arrays with explicit, per-device control over parallelization.

```python { .api }
def shard_map(
    f: Callable,
    mesh: Mesh,
    in_specs,
    out_specs,
    check_rep=True,
) -> Callable:
    """
    Map f over the shards of its inputs, one invocation per device.

    Inside f, each argument is the per-device block described by in_specs,
    not the global array. f's outputs are per-device blocks that out_specs
    stitches back into global arrays. Collectives such as jax.lax.psum over
    mesh axis names can be used inside f.

    Args:
        f: Per-shard function to transform.
        mesh: Device mesh for the computation.
        in_specs: PartitionSpec (or pytree of them) for the positional
            arguments.
        out_specs: PartitionSpec (or pytree of them) for the outputs.
        check_rep: Whether to check that outputs out_specs declares as
            replicated really are. This is the name used by
            jax.experimental.shard_map.shard_map; jax.shard_map names this
            option check_vma.

    Returns:
        Function that takes and returns global (sharded) arrays.
    """

# Short name for shard_map within this listing only.
# NOTE(review): confirm that jax.smap (if your JAX version has it) really is
# an alias of jax.shard_map before relying on this; it may have a different
# signature.
smap = shard_map

def with_sharding_constraint(x, sharding) -> Array:
    """
    Constrain the sharding of an intermediate value.

    In JAX this is exposed as jax.lax.with_sharding_constraint.

    Args:
        x: Array whose layout should be constrained.
        sharding: The sharding x should have.

    Returns:
        x, carrying the requested sharding constraint.
    """
    ...
```

Usage examples:

```python
# Create a 2x2 device mesh from the first four devices (needs >= 4 devices).
# Device objects are not numeric, so build the grid with NumPy, not jnp.
devices_array = np.array(jax.devices()[:4]).reshape(2, 2)
mesh = jax.sharding.Mesh(devices_array, ('data', 'model'))

# Define sharding specifications
data_sharding = NamedSharding(mesh, P('data', None))    # split axis 0 over 'data'
model_sharding = NamedSharding(mesh, P(None, 'model'))  # split axis 1 over 'model'
replicated_sharding = NamedSharding(mesh, P())          # full copy on every device

# Create sharded arrays
x = jax.random.normal(jax.random.key(0), (8, 4))
x_sharded = jax.device_put(x, data_sharding)

weights = jax.random.normal(jax.random.key(1), (4, 8))
weights_sharded = jax.device_put(weights, model_sharding)


# jit partitions the computation according to the input shardings
@jax.jit
def matmul_fn(x, w):
    return x @ w


result = matmul_fn(x_sharded, weights_sharded)


# Explicit per-device control: single_device_fn sees local blocks
# (x_shard has shape (4, 4), w_shard has shape (4, 4)).
def single_device_fn(x_shard, w_shard):
    return x_shard @ w_shard


# jax.shard_map is the newer spelling; older JAX releases provide
# jax.experimental.shard_map.shard_map instead.
parallel_fn = jax.shard_map(
    single_device_fn,
    mesh=mesh,
    in_specs=(P('data', None), P(None, 'model')),
    out_specs=P('data', 'model'),
)

result = parallel_fn(x_sharded, weights_sharded)  # global shape (8, 8)
```

### Memory Management

Inspect live device buffers and clear JAX's compilation caches. Device memory itself is released when the last reference to an array is dropped, or explicitly with `Array.delete()`.

```python { .api }
def live_arrays(platform=None) -> list[Array]:
    """
    Return the arrays currently alive in device memory.

    Useful for tracking down memory leaks. An array stays listed, and keeps
    its device memory, until every Python reference to it is dropped or it
    is explicitly deleted with Array.delete().

    Args:
        platform: Optional platform name to restrict the result to; None
            uses the default backend.

    Returns:
        List of live Array objects.
    """

def clear_caches() -> None:
    """
    Clear JAX's compilation and tracing caches.

    This drops cached executables and traced functions, so the next call
    to a jitted function recompiles it. It does NOT free device memory held
    by live arrays; that memory is released when the arrays are deleted or
    garbage-collected. Avoid calling it inside hot loops.
    """
```

### Configuration and Backend Management

Configure backend selection, transfer guards, and the default device.

```python { .api }
# Backend selection takes effect only if it is set before JAX initializes a
# backend (i.e. before the first computation or jax.devices() call). The
# JAX_PLATFORMS environment variable is equivalent. The older
# 'jax_platform_name' option is deprecated.
jax.config.update('jax_platforms', 'cpu')     # CPU only
# jax.config.update('jax_platforms', 'cuda')  # NVIDIA GPU only
# jax.config.update('jax_platforms', 'tpu')   # TPU only

# Transfer guards catch unintentional host<->device transfers. Valid levels:
#   'allow'             - allow all transfers (default)
#   'log'               - log implicit transfers, allow explicit ones
#   'disallow'          - raise on implicit transfers, allow explicit ones
#   'log_explicit'      - log both implicit and explicit transfers
#   'disallow_explicit' - raise on both implicit and explicit transfers
jax.config.update('jax_transfer_guard', 'allow')

# Default device for computations without an explicit placement. Any Device
# object works, e.g. jax.devices('gpu')[0] on a GPU machine.
jax.config.update('jax_default_device', jax.devices()[0])
```

### Array and Device Properties

Inspect where an array is placed and the properties of each device.

```python { .api }
# Array device methods
476
array.device() -> Device # Get device containing array
477
array.devices() -> set[Device] # Get all devices for distributed array
478
array.sharding -> Sharding # Get array's sharding specification
479
array.is_fully_replicated -> bool # Check if array is replicated
480
array.is_fully_addressable -> bool # Check if array is fully addressable
481
482
# Device properties
class Device:
    """A compute device as seen by JAX (CPU, GPU, or TPU)."""

    platform: str        # platform name: 'cpu', 'gpu' or 'tpu'
    device_kind: str     # device kind description string
    id: int              # device ID within its platform
    host_id: int         # ID of the host the device belongs to (legacy name)
    process_index: int   # index of the process the device belongs to

    def __repr__(self) -> str: ...
    def __str__(self) -> str: ...
```

## Advanced Usage Patterns

### Multi-Device Training

```python
# Setup for data-parallel training
502
def create_train_setup(num_devices):
503
# Create mesh for data parallelism
504
mesh = jax.make_mesh((num_devices,), ('batch',))
505
506
# Sharding specifications
507
batch_sharding = NamedSharding(mesh, P('batch')) # Batch dimension sharded
508
replicated_sharding = NamedSharding(mesh, P()) # Parameters replicated
509
510
return mesh, batch_sharding, replicated_sharding
511
512
def distributed_train_step(params, batch, optimizer_state):
513
# All arrays should already have appropriate sharding
514
grads = jax.grad(loss_fn)(params, batch)
515
516
# Update step automatically uses sharding from inputs
517
new_params, new_state = optimizer.update(grads, optimizer_state, params)
518
return new_params, new_state
519
520
# JIT compile with sharding
521
distributed_train_step = jax.jit(
522
distributed_train_step,
523
in_shardings=(replicated_sharding, batch_sharding, replicated_sharding),
524
out_shardings=(replicated_sharding, replicated_sharding)
525
)
526
```

### Model Parallelism

```python
# Setup for model-parallel computation (needs 8 devices for the 2 x 4 mesh)
def create_model_parallel_setup():
    """Return (mesh, input_sharding, weight_sharding, output_sharding)."""
    # 2D mesh: batch x model dimensions
    mesh = jax.make_mesh((2, 4), ('batch', 'model'))

    # Different sharding strategies
    input_sharding = NamedSharding(mesh, P('batch', None))   # rows over 'batch'
    weight_sharding = NamedSharding(mesh, P(None, 'model'))  # columns over 'model'
    output_sharding = NamedSharding(mesh, P('batch', 'model'))

    return mesh, input_sharding, weight_sharding, output_sharding


def model_parallel_layer(x, weights):
    # Plain matrix multiply; under jit, XLA inserts any needed communication
    return x @ weights


mesh, input_sharding, weight_sharding, output_sharding = create_model_parallel_setup()

# Example inputs: x rows divisible by 2 ('batch'), weight columns by 4 ('model')
x = jax.random.normal(jax.random.key(0), (16, 32))
weights = jax.random.normal(jax.random.key(1), (32, 64))

# Shard arrays according to strategy
x = jax.device_put(x, input_sharding)
weights = jax.device_put(weights, weight_sharding)

# out_shardings pins the result layout instead of leaving it to propagation
layer = jax.jit(model_parallel_layer, out_shardings=output_sharding)
result = layer(x, weights)  # laid out according to output_sharding
```

### Memory-Efficient Inference

```python
def memory_efficient_inference(model_fn, large_input, chunk_size=1000):
    """
    Run model_fn over large_input in chunks, keeping one chunk on device at a time.

    Args:
        model_fn: Function applied to each chunk (typically jitted). The
            final chunk may be shorter, which costs one extra compilation.
        large_input: Sliceable array-like whose leading axis is chunked.
        chunk_size: Number of leading-axis rows per chunk.

    Returns:
        NumPy array in host memory holding the per-chunk results
        concatenated along axis 0.

    Raises:
        ValueError: If large_input is empty (nothing to concatenate) or
            chunk_size is 0.
    """
    results = []
    for start in range(0, len(large_input), chunk_size):
        chunk = large_input[start:start + chunk_size]

        # Move to device, compute, move back to host
        device_chunk = jax.device_put(chunk)
        device_result = model_fn(device_chunk)
        results.append(jax.device_get(device_result))

        # Drop the device buffers before the next transfer to keep peak
        # memory low. jax.clear_caches() would not free them; it only
        # discards compiled code and forces model_fn to recompile.
        del device_chunk, device_result

    # Concatenate on the host: jnp.concatenate would copy everything back to device.
    return np.concatenate(results)
```

### Cross-Device Communication Patterns

```python
# Collective operations with pmap. The collective's axis_name must be bound
# by pmap itself; a bare @jax.pmap leaves 'batch' unbound and the psum /
# all_gather fail with an unbound-axis-name error. (jax.shard_map is the
# preferred API for new code.)
@partial(jax.pmap, axis_name='batch')
def allreduce_example(x):
    # Sum over all devices; every device receives the total
    return jax.lax.psum(x, axis_name='batch')


@partial(jax.pmap, axis_name='batch')
def allgather_example(x):
    # Every device receives the stacked values from all devices
    return jax.lax.all_gather(x, axis_name='batch')


# pmap maps over local devices, so stack one copy of data per local device:
# replicated_data has shape (num_local_devices, 4).
data = jnp.arange(4.0)
replicated_data = jax.device_put_replicated(data, jax.local_devices())
summed_result = allreduce_example(replicated_data)    # each row == data * num_local_devices
gathered_result = allgather_example(replicated_data)  # shape (n, n, 4)
```