or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

array-operations.md, compilation-execution.md, custom-operations.md, device-management.md, hardware-operations.md, index.md, plugin-system.md, sharding.md, xla-client.md

docs/sharding.md

0

# Sharding and Distribution

1

2

Sharding strategies for distributing computations across multiple devices and nodes, including SPMD, GSPMD, and custom sharding patterns.

3

4

## Capabilities

5

6

### Sharding Base Classes

7

8

Core sharding interfaces and implementations for different distribution strategies.

9

10

```python { .api }

11

class Sharding:

12

"""Base class for all sharding implementations."""

13

14

class NamedSharding(Sharding):

15

"""Sharding with named mesh and partition specifications."""

16

17

def __init__(

18

self,

19

mesh: Any,

20

spec: Any,

21

*,

22

memory_kind: str | None = None,

23

_logical_device_ids: tuple[int, ...] | None = None,

24

): ...

25

26

mesh: Any

27

spec: Any

28

_memory_kind: str | None

29

_internal_device_list: DeviceList

30

_logical_device_ids: tuple[int, ...] | None

31

32

class SingleDeviceSharding(Sharding):

33

"""Sharding for single device placement."""

34

35

def __init__(self, device: Device, *, memory_kind: str | None = None): ...

36

37

_device: Device

38

_memory_kind: str | None

39

_internal_device_list: DeviceList

40

41

class PmapSharding(Sharding):

42

"""Sharding for pmap-style parallelism."""

43

44

def __init__(

45

self, devices: Sequence[Any], sharding_spec: pmap_lib.ShardingSpec

46

): ...

47

48

devices: list[Any]

49

sharding_spec: pmap_lib.ShardingSpec

50

_internal_device_list: DeviceList

51

52

class GSPMDSharding(Sharding):

53

"""GSPMD (General SPMD) sharding implementation."""

54

55

def __init__(

56

self,

57

devices: Sequence[Device],

58

op_sharding: OpSharding | HloSharding,

59

*,

60

memory_kind: str | None = None,

61

_device_list: DeviceList | None = None,

62

): ...

63

64

_devices: tuple[Device, ...]

65

_hlo_sharding: HloSharding

66

_memory_kind: str | None

67

_internal_device_list: DeviceList

68

```

69

70

### HLO Sharding

71

72

Low-level HLO sharding specifications for fine-grained control over data distribution.

73

74

```python { .api }

75

class HloSharding:

76

"""HLO-level sharding specification."""

77

78

@staticmethod

79

def from_proto(proto: OpSharding) -> HloSharding:

80

"""Create HloSharding from OpSharding proto."""

81

82

@staticmethod

83

def from_string(sharding: str) -> HloSharding:

84

"""Create HloSharding from string representation."""

85

86

@staticmethod

87

def tuple_sharding(

88

shape: Shape, shardings: Sequence[HloSharding]

89

) -> HloSharding:

90

"""Create tuple sharding from component shardings."""

91

92

@staticmethod

93

def iota_tile(

94

dims: Sequence[int],

95

reshape_dims: Sequence[int],

96

transpose_perm: Sequence[int],

97

subgroup_types: Sequence[OpSharding_Type],

98

) -> HloSharding:

99

"""Create iota-based tiled sharding."""

100

101

@staticmethod

102

def replicate() -> HloSharding:

103

"""Create replicated sharding (data copied to all devices)."""

104

105

@staticmethod

106

def manual() -> HloSharding:

107

"""Create manual sharding (user-controlled placement)."""

108

109

@staticmethod

110

def unknown() -> HloSharding:

111

"""Create unknown sharding (to be inferred)."""

112

113

def is_replicated(self) -> bool:

114

"""Check if sharding is replicated."""

115

116

def is_manual(self) -> bool:

117

"""Check if sharding is manual."""

118

119

def is_unknown(self) -> bool:

120

"""Check if sharding is unknown."""

121

122

def is_tiled(self) -> bool:

123

"""Check if sharding is tiled."""

124

125

def is_maximal(self) -> bool:

126

"""Check if sharding is maximal (single device)."""

127

128

def num_devices(self) -> int:

129

"""Get number of devices in sharding."""

130

131

def tuple_elements(self) -> list[HloSharding]:

132

"""Get tuple element shardings."""

133

134

def tile_assignment_dimensions(self) -> Sequence[int]:

135

"""Get tile assignment dimensions."""

136

137

def tile_assignment_devices(self) -> Sequence[int]:

138

"""Get tile assignment device IDs."""

139

140

def to_proto(self) -> OpSharding:

141

"""Convert to OpSharding proto."""

142

```

143

144

### Operation Sharding

145

146

Protocol buffer-based sharding specifications for XLA operations.

147

148

```python { .api }

149

class OpSharding_Type(enum.IntEnum):

150

REPLICATED = ...

151

MAXIMAL = ...

152

TUPLE = ...

153

OTHER = ...

154

MANUAL = ...

155

UNKNOWN = ...

156

157

class OpSharding:

158

"""Operation sharding specification."""

159

160

Type: type[OpSharding_Type]

161

type: OpSharding_Type

162

replicate_on_last_tile_dim: bool

163

last_tile_dims: Sequence[OpSharding_Type]

164

tile_assignment_dimensions: Sequence[int]

165

tile_assignment_devices: Sequence[int]

166

iota_reshape_dims: Sequence[int]

167

iota_transpose_perm: Sequence[int]

168

tuple_shardings: Sequence[OpSharding]

169

is_shard_group: bool

170

shard_group_id: int

171

shard_group_type: OpSharding_ShardGroupType

172

173

def ParseFromString(self, s: bytes) -> None:

174

"""Parse from serialized bytes."""

175

176

def SerializeToString(self) -> bytes:

177

"""Serialize to bytes."""

178

179

def clone(self) -> OpSharding:

180

"""Create a copy of this sharding."""

181

```

182

183

### Partition Specifications

184

185

Utilities for specifying how arrays should be partitioned across device meshes.

186

187

```python { .api }

188

class PartitionSpec:

189

"""Specification for how to partition arrays."""

190

191

def __init__(self, *partitions, unreduced: Set[Any] | None = None): ...

192

193

def __hash__(self): ...

194

def __eq__(self, other): ...

195

196

class UnconstrainedSingleton:

197

"""Singleton representing unconstrained partitioning."""

198

199

def __repr__(self) -> str: ...

200

def __reduce__(self) -> Any: ...

201

202

UNCONSTRAINED_PARTITION: UnconstrainedSingleton

203

204

def canonicalize_partition(partition: Any) -> Any:

205

"""Canonicalize partition specification."""

206

```

207

208

## Usage Examples

209

210

### Basic Sharding Setup

211

212

```python

213

from jaxlib import xla_client

214

import numpy as np

215

216

# Create client with multiple devices

217

client = xla_client.make_cpu_client()

218

devices = client.local_devices()

219

220

if len(devices) >= 2:

221

# Create single device sharding

222

single_sharding = xla_client.SingleDeviceSharding(devices[0])

223

224

# Create GSPMD sharding for distribution

225

# First create OpSharding for 2-device split

226

op_sharding = xla_client.OpSharding()

227

op_sharding.type = xla_client.OpSharding_Type.OTHER

228

op_sharding.tile_assignment_dimensions = [2, 1] # Split first dimension

229

op_sharding.tile_assignment_devices = [0, 1] # Use devices 0 and 1

230

231

gspmd_sharding = xla_client.GSPMDSharding(

232

devices[:2],

233

op_sharding

234

)

235

236

print(f"GSPMD devices: {gspmd_sharding._devices}")

237

print(f"Number of devices: {gspmd_sharding._hlo_sharding.num_devices()}")

238

```

239

240

### HLO Sharding Operations

241

242

```python

243

from jaxlib import xla_client

244

245

# Create different types of HLO shardings

246

replicated = xla_client.HloSharding.replicate()

247

manual = xla_client.HloSharding.manual()

248

unknown = xla_client.HloSharding.unknown()

249

250

print(f"Replicated: {replicated.is_replicated()}")

251

print(f"Manual: {manual.is_manual()}")

252

print(f"Unknown: {unknown.is_unknown()}")

253

254

# Create sharding from string representation

255

sharding_str = "{devices=[2,1]0,1}"

256

string_sharding = xla_client.HloSharding.from_string(sharding_str)

257

print(f"Devices in sharding: {string_sharding.num_devices()}")

258

print(f"Is tiled: {string_sharding.is_tiled()}")

259

```

260

261

### Partition Specifications

262

263

```python

264

from jaxlib import xla_client

265

266

# Create partition specifications

267

spec1 = xla_client.PartitionSpec('data') # Partition along 'data' axis

268

spec2 = xla_client.PartitionSpec('batch', 'model') # Partition along two axes

269

spec3 = xla_client.PartitionSpec(None, 'data') # No partition on first axis

270

271

# Use unconstrained partition

272

unconstrained = xla_client.UNCONSTRAINED_PARTITION

273

print(f"Unconstrained: {unconstrained}")

274

275

# Canonicalize partition specs

276

canonical = xla_client.canonicalize_partition(('data', None))

277

print(f"Canonical partition: {canonical}")

278

```