or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

core-transformations.md, device-memory.md, experimental.md, index.md, low-level-ops.md, neural-networks.md, numpy-compatibility.md, random-numbers.md, scipy-compatibility.md, tree-operations.md

docs/device-memory.md

0

# Device and Memory Management

1

2

JAX provides comprehensive device management and distributed computing capabilities, enabling efficient use of CPUs, GPUs, and TPUs. This includes device placement, memory management, sharding for multi-device computation, and distributed array operations.

3

4

## Core Imports

5

6

```python

7

import jax

8

from jax import devices, device_put, make_mesh

9

from jax.sharding import NamedSharding, PartitionSpec as P

10

```

11

12

## Capabilities

13

14

### Device Discovery and Information

15

16

Query available devices and their properties for computation placement and resource management.

17

18

```python { .api }

19

def devices(backend=None) -> list[Device]:

20

"""

21

Get list of all available devices.

22

23

Args:

24

backend: Optional backend name ('cpu', 'gpu', 'tpu')

25

26

Returns:

27

List of available Device objects

28

"""

29

30

def local_devices(process_index=None, backend=None) -> list[Device]:

31

"""

32

Get list of devices local to current process.

33

34

Args:

35

process_index: Process index (None for current process)

36

backend: Optional backend name

37

38

Returns:

39

List of local Device objects

40

"""

41

42

def device_count(backend=None) -> int:

43

"""

44

Get total number of devices across all processes.

45

46

Args:

47

backend: Optional backend name

48

49

Returns:

50

Total device count

51

"""

52

53

def local_device_count(backend=None) -> int:

54

"""

55

Get number of devices on current process.

56

57

Args:

58

backend: Optional backend name

59

60

Returns:

61

Local device count

62

"""

63

64

def host_count(backend=None) -> int:

65

"""

66

Get number of hosts in distributed computation.

67

68

Args:

69

backend: Optional backend name

70

71

Returns:

72

Host count

73

"""

74

75

def host_id(backend=None) -> int:

76

"""

77

Get ID of current host.

78

79

Args:

80

backend: Optional backend name

81

82

Returns:

83

Current host ID

84

"""

85

86

def host_ids(backend=None) -> list[int]:

87

"""

88

Get list of all host IDs.

89

90

Args:

91

backend: Optional backend name

92

93

Returns:

94

List of host IDs

95

"""

96

97

def process_count(backend=None) -> int:

98

"""

99

Get number of processes in distributed computation.

100

101

Args:

102

backend: Optional backend name

103

104

Returns:

105

Process count

106

"""

107

108

def process_index(backend=None) -> int:

109

"""

110

Get index of current process.

111

112

Args:

113

backend: Optional backend name

114

115

Returns:

116

Current process index

117

"""

118

119

def process_indices(backend=None) -> list[int]:

120

"""

121

Get list of all process indices.

122

123

Args:

124

backend: Optional backend name

125

126

Returns:

127

List of process indices

128

"""

129

130

def default_backend() -> str:

131

"""

132

Get name of default backend.

133

134

Returns:

135

Default backend name string

136

"""

137

```

138

139

### Device Placement and Data Movement

140

141

Control where computations run and move data between devices and host memory.

142

143

```python { .api }

144

def device_put(x, device=None, src=None) -> Array:

145

"""

146

Move array to specified device.

147

148

Args:

149

x: Array or array-like object to move

150

device: Target device (None for default device)

151

src: Source device for the transfer

152

153

Returns:

154

Array placed on target device

155

"""

156

157

def device_put_sharded(

158

sharded_values: list,

159

devices: list[Device],

160

indices=None

161

) -> Array:

162

"""

163

Create sharded array from per-device values.

164

165

Args:

166

sharded_values: List of arrays, one per device

167

devices: List of target devices

168

indices: Optional sharding indices

169

170

Returns:

171

Distributed array sharded across devices

172

"""

173

174

def device_put_replicated(x, devices: list[Device]) -> Array:

175

"""

176

Replicate array across multiple devices.

177

178

Args:

179

x: Array to replicate

180

devices: List of target devices

181

182

Returns:

183

Array replicated across all specified devices

184

"""

185

186

def device_get(x) -> Any:

187

"""

188

Move array from device to host memory as NumPy array.

189

190

Args:

191

x: Array to move to host

192

193

Returns:

194

NumPy array in host memory

195

"""

196

197

def copy_to_host_async(x) -> Any:

198

"""

199

Asynchronously copy array to host memory.

200

201

Args:

202

x: Array to copy

203

204

Returns:

205

Future-like object for async copy

206

"""

207

208

def block_until_ready(x) -> Array:

209

"""

210

Block until array computation is complete and ready.

211

212

Args:

213

x: Array to wait for

214

215

Returns:

216

The same array, guaranteed to be ready

217

"""

218

```

219

220

Usage examples:

221

```python

222

# Check available devices

223

all_devices = jax.devices()

224

print(f"Available devices: {all_devices}")

225

print(f"Device count: {jax.device_count()}")

226

227

# Move data to specific device

228

cpu_data = jnp.array([1, 2, 3, 4])

229

if jax.devices('gpu'):

230

gpu_data = jax.device_put(cpu_data, jax.devices('gpu')[0])

231

print(f"Data is on: {gpu_data.device()}")

232

233

# Move back to host

234

host_data = jax.device_get(gpu_data) # Returns NumPy array

235

236

# Explicit device placement in computations

237

with jax.default_device(jax.devices('cpu')[0]):

238

cpu_result = jnp.sum(jnp.array([1, 2, 3]))

239

```

240

241

### Sharding and Distributed Arrays

242

243

Define how arrays are distributed across multiple devices for parallel computation.

244

245

```python { .api }

246

class NamedSharding:

247

"""

248

Sharding specification using named mesh axes.

249

250

Defines how arrays are partitioned across devices using logical axis names.

251

"""

252

253

def __init__(self, mesh, spec):

254

"""

255

Create named sharding specification.

256

257

Args:

258

mesh: Device mesh with named axes

259

spec: Partition specification (PartitionSpec)

260

"""

261

self.mesh = mesh

262

self.spec = spec

263

264

class PartitionSpec:

265

"""

266

Specification for how to partition array dimensions across mesh axes.

267

268

Use P(axis_names...) to create partition specifications.

269

"""

270

pass

271

272

# Alias for PartitionSpec

273

P = PartitionSpec

274

275

def make_mesh(mesh_shape, axis_names) -> Mesh:

276

"""

277

Create device mesh for distributed computation.

278

279

Args:

280

mesh_shape: Shape of device mesh (tuple of integers)

281

axis_names: Names for mesh axes (tuple of strings)

282

283

Returns:

284

Mesh object representing device layout

285

"""

286

287

class Mesh:

288

"""Device mesh for distributed computation."""

289

devices: Array # Device array in mesh shape

290

axis_names: tuple[str, ...] # Names of mesh axes

291

292

@property

293

def shape(self) -> dict[str, int]:

294

"""Dictionary mapping axis names to sizes."""

295

296

@property

297

def size(self) -> int:

298

"""Total number of devices in mesh."""

299

300

def make_array_from_single_device_arrays(

301

arrays: list[Array],

302

sharding: Sharding

303

) -> Array:

304

"""

305

Create distributed array from per-device arrays.

306

307

Args:

308

arrays: List of arrays on different devices

309

sharding: Sharding specification

310

311

Returns:

312

Distributed array with specified sharding

313

"""

314

315

def make_array_from_callback(

316

shape: tuple[int, ...],

317

sharding: Sharding,

318

data_callback: Callable

319

) -> Array:

320

"""

321

Create distributed array using callback function.

322

323

Args:

324

shape: Global array shape

325

sharding: Sharding specification

326

data_callback: Function to generate data for each shard

327

328

Returns:

329

Distributed array created from callback

330

"""

331

332

def make_array_from_process_local_data(

333

sharding: Sharding,

334

local_data: Array

335

) -> Array:

336

"""

337

Create distributed array from process-local data.

338

339

Args:

340

sharding: Sharding specification

341

local_data: Data local to current process

342

343

Returns:

344

Distributed array assembled from local data

345

"""

346

```

347

348

### Sharded Computation

349

350

Execute computations on sharded arrays with explicit control over parallelization.

351

352

```python { .api }

353

def shard_map(

354

f: Callable,

355

mesh: Mesh,

356

in_specs,

357

out_specs,

358

check_rep=True

359

) -> Callable:

360

"""

361

Transform function to operate on sharded arrays.

362

363

Args:

364

f: Function to transform

365

mesh: Device mesh for computation

366

in_specs: Input sharding specifications

367

out_specs: Output sharding specifications

368

check_rep: Whether to check for replication consistency

369

370

Returns:

371

Function that operates on globally sharded arrays

372

"""

373

374

# Alias for shard_map

375

smap = shard_map

376

377

def with_sharding_constraint(x, sharding) -> Array:

378

"""

379

Add sharding constraint to array.

380

381

Args:

382

x: Input array

383

sharding: Desired sharding specification

384

385

Returns:

386

Array with sharding constraint applied

387

"""

388

```

389

390

Usage examples:

391

```python

392

# Create 2x2 device mesh

393

devices_array = jnp.array(jax.devices()[:4]).reshape(2, 2)

394

mesh = jax.make_mesh((2, 2), ('data', 'model'))

395

396

# Define sharding specifications

397

data_sharding = NamedSharding(mesh, P('data', None)) # Shard first axis across 'data'

398

model_sharding = NamedSharding(mesh, P(None, 'model')) # Shard second axis across 'model'

399

replicated_sharding = NamedSharding(mesh, P()) # Replicated across all devices

400

401

# Create sharded arrays

402

x = jax.random.normal(jax.random.key(0), (8, 4))

403

x_sharded = jax.device_put(x, data_sharding)

404

405

weights = jax.random.normal(jax.random.key(1), (4, 8))

406

weights_sharded = jax.device_put(weights, model_sharding)

407

408

# Computation with sharded arrays is automatically parallelized

409

@jax.jit

410

def matmul_fn(x, w):

411

return x @ w

412

413

result = matmul_fn(x_sharded, weights_sharded) # Automatically sharded computation

414

415

# Explicit sharding control

416

def single_device_fn(x_shard, w_shard):

417

return x_shard @ w_shard

418

419

parallel_fn = jax.shard_map(

420

single_device_fn,

421

mesh=mesh,

422

in_specs=(P('data', None), P(None, 'model')),

423

out_specs=P('data', 'model')

424

)

425

426

result = parallel_fn(x_sharded, weights_sharded)

427

```

428

429

### Memory Management

430

431

Control memory usage and optimize performance through explicit memory management.

432

433

```python { .api }

434

def live_arrays() -> list[Array]:

435

"""

436

Get list of arrays currently alive in memory.

437

438

Returns:

439

List of live Array objects

440

"""

441

442

def clear_caches() -> None:

443

"""

444

Clear JAX's internal caches to free memory.

445

446

Clears all compilation and staging caches; the persistent compilation cache is not cleared.

447

"""

448

```

449

450

### Configuration and Backend Management

451

452

Configure device behavior and backend selection.

453

454

```python { .api }

455

# Configuration through jax.config

456

jax.config.update('jax_platform_name', 'cpu') # Force CPU backend

457

jax.config.update('jax_platform_name', 'gpu') # Force GPU backend

458

jax.config.update('jax_platform_name', 'tpu') # Force TPU backend

459

460

# Transfer guards to catch unintentional device transfers

461

jax.config.update('jax_transfer_guard', 'allow') # Default: allow all transfers

462

jax.config.update('jax_transfer_guard', 'log') # Log transfers

463

jax.config.update('jax_transfer_guard', 'disallow') # Disallow transfers

464

jax.config.update('jax_transfer_guard', 'log_explicit_device_put') # Log explicit transfers

465

466

# Default device configuration

467

jax.config.update('jax_default_device', jax.devices('gpu')[0]) # Set default device

468

```

469

470

### Array and Device Properties

471

472

Inspect array placement and device properties.

473

474

```python { .api }

475

# Array device methods

476

array.device() -> Device # Get device containing array

477

array.devices() -> set[Device] # Get all devices for distributed array

478

array.sharding -> Sharding # Get array's sharding specification

479

array.is_fully_replicated -> bool # Check if array is replicated

480

array.is_fully_addressable -> bool # Check if array is fully addressable

481

482

# Device properties

483

class Device:

484

"""Device object representing compute accelerator."""

485

486

platform: str # Platform name ('cpu', 'gpu', 'tpu')

487

device_kind: str # Device kind string

488

id: int # Device ID within platform

489

host_id: int # Host ID containing device

490

process_index: int # Process index containing device

491

492

def __str__(self) -> str: ...

493

def __repr__(self) -> str: ...

494

```

495

496

## Advanced Usage Patterns

497

498

### Multi-Device Training

499

500

```python

501

# Setup for data-parallel training

502

def create_train_setup(num_devices):

503

# Create mesh for data parallelism

504

mesh = jax.make_mesh((num_devices,), ('batch',))

505

506

# Sharding specifications

507

batch_sharding = NamedSharding(mesh, P('batch')) # Batch dimension sharded

508

replicated_sharding = NamedSharding(mesh, P()) # Parameters replicated

509

510

return mesh, batch_sharding, replicated_sharding

511

512

def distributed_train_step(params, batch, optimizer_state):

513

# All arrays should already have appropriate sharding

514

grads = jax.grad(loss_fn)(params, batch)

515

516

# Update step automatically uses sharding from inputs

517

new_params, new_state = optimizer.update(grads, optimizer_state, params)

518

return new_params, new_state

519

520

# JIT compile with sharding

521

distributed_train_step = jax.jit(

522

distributed_train_step,

523

in_shardings=(replicated_sharding, batch_sharding, replicated_sharding),

524

out_shardings=(replicated_sharding, replicated_sharding)

525

)

526

```

527

528

### Model Parallelism

529

530

```python

531

# Setup for model-parallel computation

532

def create_model_parallel_setup():

533

# 2D mesh: batch x model dimensions

534

mesh = jax.make_mesh((2, 4), ('batch', 'model'))

535

536

# Different sharding strategies

537

input_sharding = NamedSharding(mesh, P('batch', None))

538

weight_sharding = NamedSharding(mesh, P(None, 'model'))

539

output_sharding = NamedSharding(mesh, P('batch', 'model'))

540

541

return mesh, input_sharding, weight_sharding, output_sharding

542

543

def model_parallel_layer(x, weights):

544

# Matrix multiply with different sharding patterns

545

return x @ weights # JAX handles the communication automatically

546

547

# Shard arrays according to strategy

548

x = jax.device_put(x, input_sharding)

549

weights = jax.device_put(weights, weight_sharding)

550

result = model_parallel_layer(x, weights) # Result has output_sharding

551

```

552

553

### Memory-Efficient Inference

554

555

```python

556

def memory_efficient_inference(model_fn, large_input):

557

# Process in chunks to manage memory

558

chunk_size = 1000

559

chunks = [large_input[i:i+chunk_size] for i in range(0, len(large_input), chunk_size)]

560

561

results = []

562

for chunk in chunks:

563

# Move to device, compute, move back to host

564

device_chunk = jax.device_put(chunk)

565

device_result = model_fn(device_chunk)

566

host_result = jax.device_get(device_result)

567

results.append(host_result)

568

569

# Optional: clear caches to free memory

570

jax.clear_caches()

571

572

return jnp.concatenate(results)

573

```

574

575

### Cross-Device Communication Patterns

576

577

```python

578

# Collective operations using pmap

579

@jax.pmap

580

def allreduce_example(x):

581

# Sum across all devices

582

return jax.lax.psum(x, axis_name='batch')

583

584

@jax.pmap

585

def allgather_example(x):

586

# Gather from all devices

587

return jax.lax.all_gather(x, axis_name='batch')

588

589

# Use with replicated data

590

replicated_data = jax.device_put_replicated(data, jax.devices())

591

summed_result = allreduce_example(replicated_data)

592

gathered_result = allgather_example(replicated_data)

593

```