XLA library for JAX providing low-level bindings and hardware acceleration support.

Sharding strategies for distributing computations across multiple devices and nodes, including SPMD, GSPMD, and custom sharding patterns. This section covers the core sharding interfaces and implementations for the different distribution strategies.
class Sharding:
    """Base class for all sharding implementations.

    Concrete strategies subclass this to describe how array data is laid
    out across devices.
    """
class NamedSharding(Sharding):
    """Sharding with named mesh and partition specifications."""

    def __init__(
        self,
        mesh: Any,
        spec: Any,
        *,
        memory_kind: str | None = None,
        _logical_device_ids: tuple[int, ...] | None = None,
    ): ...

    # Attributes mirroring the constructor arguments plus internal state.
    mesh: Any
    spec: Any
    _memory_kind: str | None
    _internal_device_list: DeviceList
    _logical_device_ids: tuple[int, ...] | None
class SingleDeviceSharding(Sharding):
    """Sharding for single device placement."""

    def __init__(self, device: Device, *, memory_kind: str | None = None): ...

    # The one device all data is placed on, plus internal bookkeeping.
    _device: Device
    _memory_kind: str | None
    _internal_device_list: DeviceList
class PmapSharding(Sharding):
    """Sharding for pmap-style parallelism."""

    def __init__(
        self, devices: Sequence[Any], sharding_spec: pmap_lib.ShardingSpec
    ): ...

    devices: list[Any]
    sharding_spec: pmap_lib.ShardingSpec
    _internal_device_list: DeviceList
class GSPMDSharding(Sharding):
"""GSPMD (General SPMD) sharding implementation."""
def __init__(
self,
devices: Sequence[Device],
op_sharding: OpSharding | HloSharding,
*,
memory_kind: str | None = None,
_device_list: DeviceList | None = None,
): ...
_devices: tuple[Device, ...]
_hlo_sharding: HloSharding
_memory_kind: str | None
_internal_device_list: DeviceListLow-level HLO sharding specifications for fine-grained control over data distribution.
class HloSharding:
"""HLO-level sharding specification."""
@staticmethod
def from_proto(proto: OpSharding) -> HloSharding:
"""Create HloSharding from OpSharding proto."""
@staticmethod
def from_string(sharding: str) -> HloSharding:
"""Create HloSharding from string representation."""
@staticmethod
def tuple_sharding(
shape: Shape, shardings: Sequence[HloSharding]
) -> HloSharding:
"""Create tuple sharding from component shardings."""
@staticmethod
def iota_tile(
dims: Sequence[int],
reshape_dims: Sequence[int],
transpose_perm: Sequence[int],
subgroup_types: Sequence[OpSharding_Type],
) -> HloSharding:
"""Create iota-based tiled sharding."""
@staticmethod
def replicate() -> HloSharding:
"""Create replicated sharding (data copied to all devices)."""
@staticmethod
def manual() -> HloSharding:
"""Create manual sharding (user-controlled placement)."""
@staticmethod
def unknown() -> HloSharding:
"""Create unknown sharding (to be inferred)."""
def is_replicated(self) -> bool:
"""Check if sharding is replicated."""
def is_manual(self) -> bool:
"""Check if sharding is manual."""
def is_unknown(self) -> bool:
"""Check if sharding is unknown."""
def is_tiled(self) -> bool:
"""Check if sharding is tiled."""
def is_maximal(self) -> bool:
"""Check if sharding is maximal (single device)."""
def num_devices(self) -> int:
"""Get number of devices in sharding."""
def tuple_elements(self) -> list[HloSharding]:
"""Get tuple element shardings."""
def tile_assignment_dimensions(self) -> Sequence[int]:
"""Get tile assignment dimensions."""
def tile_assignment_devices(self) -> Sequence[int]:
"""Get tile assignment device IDs."""
def to_proto(self) -> OpSharding:
"""Convert to OpSharding proto."""Protocol buffer-based sharding specifications for XLA operations.
class OpSharding_Type(enum.IntEnum):
    """Enumeration of operation sharding types.

    Values are elided in this stub (`...`); see the XLA OpSharding proto
    for the concrete integer assignments.
    """

    REPLICATED = ...
    MAXIMAL = ...
    TUPLE = ...
    OTHER = ...
    MANUAL = ...
    UNKNOWN = ...
class OpSharding:
"""Operation sharding specification."""
Type: type[OpSharding_Type]
type: OpSharding_Type
replicate_on_last_tile_dim: bool
last_tile_dims: Sequence[OpSharding_Type]
tile_assignment_dimensions: Sequence[int]
tile_assignment_devices: Sequence[int]
iota_reshape_dims: Sequence[int]
iota_transpose_perm: Sequence[int]
tuple_shardings: Sequence[OpSharding]
is_shard_group: bool
shard_group_id: int
shard_group_type: OpSharding_ShardGroupType
def ParseFromString(self, s: bytes) -> None:
"""Parse from serialized bytes."""
def SerializeToString(self) -> bytes:
"""Serialize to bytes."""
def clone(self) -> OpSharding:
"""Create a copy of this sharding."""Utilities for specifying how arrays should be partitioned across device meshes.
class PartitionSpec:
"""Specification for how to partition arrays."""
def __init__(self, *partitions, unreduced: Set[Any] | None = None): ...
def __hash__(self): ...
def __eq__(self, other): ...
class UnconstrainedSingleton:
    """Singleton representing unconstrained partitioning."""

    def __repr__(self) -> str: ...
    def __reduce__(self) -> Any: ...

# Module-level singleton instance marking an unconstrained partition.
UNCONSTRAINED_PARTITION: UnconstrainedSingleton
def canonicalize_partition(partition: Any) -> Any:
    """Canonicalize partition specification."""

from jaxlib import xla_client
# Example: constructing single-device and GSPMD shardings.
import numpy as np

# Create client with multiple devices.
client = xla_client.make_cpu_client()
devices = client.local_devices()

if len(devices) >= 2:
    # Create single device sharding.
    single_sharding = xla_client.SingleDeviceSharding(devices[0])

    # Create GSPMD sharding for distribution:
    # first build an OpSharding proto describing a 2-device split.
    op_sharding = xla_client.OpSharding()
    op_sharding.type = xla_client.OpSharding_Type.OTHER
    op_sharding.tile_assignment_dimensions = [2, 1]  # Split first dimension
    op_sharding.tile_assignment_devices = [0, 1]  # Use devices 0 and 1

    gspmd_sharding = xla_client.GSPMDSharding(
        devices[:2],
        op_sharding
    )
    # NOTE(review): these read private attributes of the sharding object.
    print(f"GSPMD devices: {gspmd_sharding._devices}")
    print(f"Number of devices: {gspmd_sharding._hlo_sharding.num_devices()}")

from jaxlib import xla_client
# Example: creating HLO shardings of different kinds.
replicated = xla_client.HloSharding.replicate()
manual = xla_client.HloSharding.manual()
unknown = xla_client.HloSharding.unknown()
print(f"Replicated: {replicated.is_replicated()}")
print(f"Manual: {manual.is_manual()}")
print(f"Unknown: {unknown.is_unknown()}")

# Create sharding from string representation.
sharding_str = "{devices=[2,1]0,1}"
string_sharding = xla_client.HloSharding.from_string(sharding_str)
print(f"Devices in sharding: {string_sharding.num_devices()}")
print(f"Is tiled: {string_sharding.is_tiled()}")

from jaxlib import xla_client
# Create partition specifications
spec1 = xla_client.PartitionSpec('data') # Partition along 'data' axis
spec2 = xla_client.PartitionSpec('batch', 'model') # Partition along two axes
spec3 = xla_client.PartitionSpec(None, 'data') # No partition on first axis
# Use unconstrained partition
unconstrained = xla_client.UNCONSTRAINED_PARTITION
print(f"Unconstrained: {unconstrained}")
# Canonicalize partition specs
canonical = xla_client.canonicalize_partition(('data', None))
print(f"Canonical partition: {canonical}")Install with Tessl CLI
npx tessl i tessl/pypi-jaxlib