# Sharding and Distribution

Sharding strategies for distributing computations across multiple devices and nodes, including SPMD, GSPMD, and custom sharding patterns.

## Capabilities

### Sharding Base Classes

Core sharding interfaces and implementations for different distribution strategies.

```python { .api }
class Sharding:
    """Base class for all sharding implementations."""

class NamedSharding(Sharding):
    """Sharding with named mesh and partition specifications."""

    def __init__(
        self,
        mesh: Any,
        spec: Any,
        *,
        memory_kind: str | None = None,
        _logical_device_ids: tuple[int, ...] | None = None,
    ): ...

    mesh: Any
    spec: Any
    _memory_kind: str | None
    _internal_device_list: DeviceList
    _logical_device_ids: tuple[int, ...] | None

class SingleDeviceSharding(Sharding):
    """Sharding for single device placement."""

    def __init__(self, device: Device, *, memory_kind: str | None = None): ...

    _device: Device
    _memory_kind: str | None
    _internal_device_list: DeviceList

class PmapSharding(Sharding):
    """Sharding for pmap-style parallelism."""

    def __init__(
        self, devices: Sequence[Any], sharding_spec: pmap_lib.ShardingSpec
    ): ...

    devices: list[Any]
    sharding_spec: pmap_lib.ShardingSpec
    _internal_device_list: DeviceList

class GSPMDSharding(Sharding):
    """GSPMD (General SPMD) sharding implementation."""

    def __init__(
        self,
        devices: Sequence[Device],
        op_sharding: OpSharding | HloSharding,
        *,
        memory_kind: str | None = None,
        _device_list: DeviceList | None = None,
    ): ...

    _devices: tuple[Device, ...]
    _hlo_sharding: HloSharding
    _memory_kind: str | None
    _internal_device_list: DeviceList
```

### HLO Sharding

Low-level HLO sharding specifications for fine-grained control over data distribution.

```python { .api }
class HloSharding:
    """HLO-level sharding specification."""

    @staticmethod
    def from_proto(proto: OpSharding) -> HloSharding:
        """Create HloSharding from OpSharding proto."""

    @staticmethod
    def from_string(sharding: str) -> HloSharding:
        """Create HloSharding from string representation."""

    @staticmethod
    def tuple_sharding(
        shape: Shape, shardings: Sequence[HloSharding]
    ) -> HloSharding:
        """Create tuple sharding from component shardings."""

    @staticmethod
    def iota_tile(
        dims: Sequence[int],
        reshape_dims: Sequence[int],
        transpose_perm: Sequence[int],
        subgroup_types: Sequence[OpSharding_Type],
    ) -> HloSharding:
        """Create iota-based tiled sharding."""

    @staticmethod
    def replicate() -> HloSharding:
        """Create replicated sharding (data copied to all devices)."""

    @staticmethod
    def manual() -> HloSharding:
        """Create manual sharding (user-controlled placement)."""

    @staticmethod
    def unknown() -> HloSharding:
        """Create unknown sharding (to be inferred)."""

    def is_replicated(self) -> bool:
        """Check if sharding is replicated."""

    def is_manual(self) -> bool:
        """Check if sharding is manual."""

    def is_unknown(self) -> bool:
        """Check if sharding is unknown."""

    def is_tiled(self) -> bool:
        """Check if sharding is tiled."""

    def is_maximal(self) -> bool:
        """Check if sharding is maximal (single device)."""

    def num_devices(self) -> int:
        """Get number of devices in sharding."""

    def tuple_elements(self) -> list[HloSharding]:
        """Get tuple element shardings."""

    def tile_assignment_dimensions(self) -> Sequence[int]:
        """Get tile assignment dimensions."""

    def tile_assignment_devices(self) -> Sequence[int]:
        """Get tile assignment device IDs."""

    def to_proto(self) -> OpSharding:
        """Convert to OpSharding proto."""
```

### Operation Sharding

Protocol buffer-based sharding specifications for XLA operations.

```python { .api }
class OpSharding_Type(enum.IntEnum):
    REPLICATED = ...
    MAXIMAL = ...
    TUPLE = ...
    OTHER = ...
    MANUAL = ...
    UNKNOWN = ...

class OpSharding:
    """Operation sharding specification."""

    Type: type[OpSharding_Type]
    type: OpSharding_Type
    replicate_on_last_tile_dim: bool
    last_tile_dims: Sequence[OpSharding_Type]
    tile_assignment_dimensions: Sequence[int]
    tile_assignment_devices: Sequence[int]
    iota_reshape_dims: Sequence[int]
    iota_transpose_perm: Sequence[int]
    tuple_shardings: Sequence[OpSharding]
    is_shard_group: bool
    shard_group_id: int
    shard_group_type: OpSharding_ShardGroupType

    def ParseFromString(self, s: bytes) -> None:
        """Parse from serialized bytes."""

    def SerializeToString(self) -> bytes:
        """Serialize to bytes."""

    def clone(self) -> OpSharding:
        """Create a copy of this sharding."""
```

### Partition Specifications

Utilities for specifying how arrays should be partitioned across device meshes.

```python { .api }
class PartitionSpec:
    """Specification for how to partition arrays."""

    def __init__(self, *partitions, unreduced: Set[Any] | None = None): ...

    def __hash__(self): ...
    def __eq__(self, other): ...

class UnconstrainedSingleton:
    """Singleton representing unconstrained partitioning."""

    def __repr__(self) -> str: ...
    def __reduce__(self) -> Any: ...

UNCONSTRAINED_PARTITION: UnconstrainedSingleton

def canonicalize_partition(partition: Any) -> Any:
    """Canonicalize partition specification."""
```

## Usage Examples

### Basic Sharding Setup

```python
from jaxlib import xla_client
import numpy as np

# Create client with multiple devices
client = xla_client.make_cpu_client()
devices = client.local_devices()

if len(devices) >= 2:
    # Create single device sharding
    single_sharding = xla_client.SingleDeviceSharding(devices[0])

    # Create GSPMD sharding for distribution
    # First create OpSharding for 2-device split
    op_sharding = xla_client.OpSharding()
    op_sharding.type = xla_client.OpSharding_Type.OTHER
    op_sharding.tile_assignment_dimensions = [2, 1]  # Split first dimension
    op_sharding.tile_assignment_devices = [0, 1]  # Use devices 0 and 1

    gspmd_sharding = xla_client.GSPMDSharding(
        devices[:2],
        op_sharding
    )

    print(f"GSPMD devices: {gspmd_sharding._devices}")
    print(f"Number of devices: {gspmd_sharding._hlo_sharding.num_devices()}")
```

### HLO Sharding Operations

```python
from jaxlib import xla_client

# Create different types of HLO shardings
replicated = xla_client.HloSharding.replicate()
manual = xla_client.HloSharding.manual()
unknown = xla_client.HloSharding.unknown()

print(f"Replicated: {replicated.is_replicated()}")
print(f"Manual: {manual.is_manual()}")
print(f"Unknown: {unknown.is_unknown()}")

# Create sharding from string representation
sharding_str = "{devices=[2,1]0,1}"
string_sharding = xla_client.HloSharding.from_string(sharding_str)
print(f"Devices in sharding: {string_sharding.num_devices()}")
print(f"Is tiled: {string_sharding.is_tiled()}")
```

### Partition Specifications

```python
from jaxlib import xla_client

# Create partition specifications
spec1 = xla_client.PartitionSpec('data')  # Partition along 'data' axis
spec2 = xla_client.PartitionSpec('batch', 'model')  # Partition along two axes
spec3 = xla_client.PartitionSpec(None, 'data')  # No partition on first axis

# Use unconstrained partition
unconstrained = xla_client.UNCONSTRAINED_PARTITION
print(f"Unconstrained: {unconstrained}")

# Canonicalize partition specs
canonical = xla_client.canonicalize_partition(('data', None))
print(f"Canonical partition: {canonical}")
```