CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-jaxlib

XLA library for JAX providing low-level bindings and hardware acceleration support

Pending
Overview
Eval results
Files

docs/sharding.md

Sharding and Distribution

Sharding strategies for distributing computations across multiple devices and nodes, including SPMD, GSPMD, and custom sharding patterns.

Capabilities

Sharding Base Classes

Core sharding interfaces and implementations for different distribution strategies.

class Sharding:
    """Base class for all sharding implementations.

    Concrete strategies below (NamedSharding, SingleDeviceSharding,
    PmapSharding, GSPMDSharding) subclass this common interface.
    """

class NamedSharding(Sharding):
    """Sharding with named mesh and partition specifications.

    Pairs a device ``mesh`` with a partition ``spec`` (see
    PartitionSpec) describing how array dimensions map onto mesh axes.

    Args:
        mesh: Device mesh the sharding is defined over.
        spec: Partition specification for the array dimensions.
        memory_kind: Optional device memory space name; ``None`` means
            the default memory kind -- TODO confirm default semantics.
        _logical_device_ids: Private override of the logical device
            ordering; ``None`` presumably uses the mesh's own order --
            NOTE(review): confirm against jaxlib internals.
    """
    
    def __init__(
        self,
        mesh: Any,
        spec: Any,
        *,
        memory_kind: str | None = None,
        _logical_device_ids: tuple[int, ...] | None = None,
    ): ...
    
    mesh: Any  # device mesh this sharding was constructed with
    spec: Any  # partition spec this sharding was constructed with
    _memory_kind: str | None  # memory space name, or None for default
    _internal_device_list: DeviceList  # internal flattened device list
    _logical_device_ids: tuple[int, ...] | None  # private device-order override

class SingleDeviceSharding(Sharding):
    """Sharding for single device placement.

    The whole value lives on one ``device``; no partitioning occurs.

    Args:
        device: The device that holds the data.
        memory_kind: Optional device memory space name; None = default.
    """
    
    def __init__(self, device: Device, *, memory_kind: str | None = None): ...
    
    _device: Device  # the single device holding the data
    _memory_kind: str | None  # memory space name, or None for default
    _internal_device_list: DeviceList  # internal one-device list

class PmapSharding(Sharding):
    """Sharding for pmap-style parallelism.

    Args:
        devices: Devices participating in the pmap.
        sharding_spec: Low-level ShardingSpec from jaxlib's pmap_lib
            describing how the value is chunked across ``devices``.
    """
    
    def __init__(
        self, devices: Sequence[Any], sharding_spec: pmap_lib.ShardingSpec
    ): ...
    
    devices: list[Any]  # participating devices
    sharding_spec: pmap_lib.ShardingSpec  # chunking description
    _internal_device_list: DeviceList  # internal device list

class GSPMDSharding(Sharding):
    """GSPMD (General SPMD) sharding implementation.

    Wraps a low-level sharding description over an explicit device list.
    The constructor accepts either an OpSharding proto or an HloSharding,
    while the stored ``_hlo_sharding`` attribute is always an HloSharding
    -- presumably protos are converted on construction; TODO confirm.

    Args:
        devices: Devices the sharding distributes data over.
        op_sharding: Sharding description (proto or HLO form).
        memory_kind: Optional device memory space name; None = default.
        _device_list: Private pre-built DeviceList to reuse instead of
            building one from ``devices``.
    """
    
    def __init__(
        self,
        devices: Sequence[Device],
        op_sharding: OpSharding | HloSharding,
        *,
        memory_kind: str | None = None,
        _device_list: DeviceList | None = None,
    ): ...
    
    _devices: tuple[Device, ...]  # devices, frozen as a tuple
    _hlo_sharding: HloSharding  # HLO-level form of the sharding
    _memory_kind: str | None  # memory space name, or None for default
    _internal_device_list: DeviceList  # internal device list

HLO Sharding

Low-level HLO sharding specifications for fine-grained control over data distribution.

class HloSharding:
    """HLO-level sharding specification.

    HLO-side description of how data is distributed across devices.
    Instances are built via the static factories below and round-trip
    to/from the OpSharding proto form via from_proto/to_proto.
    """
    
    # ---- factories ----
    
    @staticmethod
    def from_proto(proto: OpSharding) -> HloSharding:
        """Create HloSharding from OpSharding proto."""
    
    @staticmethod
    def from_string(sharding: str) -> HloSharding:
        """Create HloSharding from string representation."""
    
    @staticmethod
    def tuple_sharding(
        shape: Shape, shardings: Sequence[HloSharding]
    ) -> HloSharding:
        """Create tuple sharding from component shardings."""
    
    @staticmethod
    def iota_tile(
        dims: Sequence[int],
        reshape_dims: Sequence[int],
        transpose_perm: Sequence[int],
        subgroup_types: Sequence[OpSharding_Type],
    ) -> HloSharding:
        """Create iota-based tiled sharding.

        NOTE(review): exact device-assignment construction from
        ``reshape_dims``/``transpose_perm`` is not shown here; confirm
        against the XLA HloSharding documentation.
        """
    
    @staticmethod
    def replicate() -> HloSharding:
        """Create replicated sharding (data copied to all devices)."""
    
    @staticmethod
    def manual() -> HloSharding:
        """Create manual sharding (user-controlled placement)."""
    
    @staticmethod
    def unknown() -> HloSharding:
        """Create unknown sharding (to be inferred)."""
    
    # ---- predicates, one per sharding kind ----
    
    def is_replicated(self) -> bool:
        """Check if sharding is replicated."""
    
    def is_manual(self) -> bool:
        """Check if sharding is manual."""
    
    def is_unknown(self) -> bool:
        """Check if sharding is unknown."""
    
    def is_tiled(self) -> bool:
        """Check if sharding is tiled."""
    
    def is_maximal(self) -> bool:
        """Check if sharding is maximal (single device)."""
    
    # ---- accessors ----
    
    def num_devices(self) -> int:
        """Get number of devices in sharding."""
    
    def tuple_elements(self) -> list[HloSharding]:
        """Get tuple element shardings."""
    
    def tile_assignment_dimensions(self) -> Sequence[int]:
        """Get tile assignment dimensions (the tile grid shape)."""
    
    def tile_assignment_devices(self) -> Sequence[int]:
        """Get tile assignment device IDs."""
    
    def to_proto(self) -> OpSharding:
        """Convert to OpSharding proto (inverse of from_proto)."""

Operation Sharding

Protocol buffer-based sharding specifications for XLA operations.

class OpSharding_Type(enum.IntEnum):
    """Kind of sharding carried by an OpSharding proto."""
    REPLICATED = ...  # data copied in full to every device
    MAXIMAL = ...     # whole value placed on a single device
    TUPLE = ...       # tuple of per-element shardings (see tuple_shardings)
    OTHER = ...       # tiled sharding described by the tile_assignment_* fields
    MANUAL = ...      # user-controlled (manual) placement
    UNKNOWN = ...     # sharding to be inferred

class OpSharding:
    """Operation sharding specification.

    Protocol-buffer-style container: plain mutable fields plus
    parse/serialize methods. For a tiled split, set ``type`` to
    OpSharding_Type.OTHER and fill in the ``tile_assignment_*`` fields.
    """
    
    Type: type[OpSharding_Type]  # alias so callers can write OpSharding.Type.OTHER
    type: OpSharding_Type  # kind of sharding (REPLICATED, OTHER/tiled, ...)
    replicate_on_last_tile_dim: bool  # presumably last tile dim replicates rather than tiles -- TODO confirm
    last_tile_dims: Sequence[OpSharding_Type]  # per-dim kinds for trailing tile dims -- TODO confirm
    tile_assignment_dimensions: Sequence[int]  # tile grid shape, one entry per dimension
    tile_assignment_devices: Sequence[int]  # device IDs filling the tile grid
    iota_reshape_dims: Sequence[int]  # iota-based assignment: reshape dims (cf. HloSharding.iota_tile)
    iota_transpose_perm: Sequence[int]  # iota-based assignment: transpose permutation
    tuple_shardings: Sequence[OpSharding]  # element shardings when type == TUPLE
    is_shard_group: bool  # shard-group membership flag -- semantics not shown here
    shard_group_id: int  # identifier of the shard group
    shard_group_type: OpSharding_ShardGroupType  # kind of shard group
    
    def ParseFromString(self, s: bytes) -> None:
        """Parse from serialized bytes."""
    
    def SerializeToString(self) -> bytes:
        """Serialize to bytes."""
    
    def clone(self) -> OpSharding:
        """Create a copy of this sharding."""

Partition Specifications

Utilities for specifying how arrays should be partitioned across device meshes.

class PartitionSpec:
    """Specification for how to partition arrays.

    Positional ``*partitions`` entries give one specification per array
    dimension (an axis name, None, or UNCONSTRAINED_PARTITION, per the
    usage examples below). Hashable and comparable, so specs can serve
    as dict keys.

    Args:
        *partitions: One partition entry per array dimension.
        unreduced: Optional set -- NOTE(review): semantics not shown
            here; confirm against the jaxlib/JAX documentation.
    """
    
    def __init__(self, *partitions, unreduced: Set[Any] | None = None): ...
    
    def __hash__(self): ...
    def __eq__(self, other): ...

class UnconstrainedSingleton:
    """Singleton representing unconstrained partitioning.

    The instance is exposed as UNCONSTRAINED_PARTITION; ``__reduce__``
    presumably preserves singleton identity across pickling -- TODO
    confirm.
    """
    
    def __repr__(self) -> str: ...
    def __reduce__(self) -> Any: ...

# The canonical unconstrained-partition sentinel instance.
UNCONSTRAINED_PARTITION: UnconstrainedSingleton

def canonicalize_partition(partition: Any) -> Any:
    """Canonicalize partition specification.

    Normalizes an accepted partition form (e.g. an axis name or a tuple
    of names/None) into a canonical representation -- exact
    normalization rules not shown here.
    """

Usage Examples

Basic Sharding Setup

# Example: constructing single-device and GSPMD shardings.
from jaxlib import xla_client
import numpy as np

# Spin up a CPU client and enumerate its local devices.
client = xla_client.make_cpu_client()
devices = client.local_devices()

if len(devices) >= 2:
    # Pin data to a single device.
    single_sharding = xla_client.SingleDeviceSharding(devices[0])

    # Describe a two-way split of the leading dimension as an
    # OpSharding proto: a 2x1 tile grid over devices 0 and 1.
    proto = xla_client.OpSharding()
    proto.type = xla_client.OpSharding_Type.OTHER
    proto.tile_assignment_dimensions = [2, 1]  # Split first dimension
    proto.tile_assignment_devices = [0, 1]     # Use devices 0 and 1

    # Wrap the proto and the first two devices in a GSPMD sharding.
    gspmd_sharding = xla_client.GSPMDSharding(devices[:2], proto)

    print(f"GSPMD devices: {gspmd_sharding._devices}")
    print(f"Number of devices: {gspmd_sharding._hlo_sharding.num_devices()}")

HLO Sharding Operations

# Example: HloSharding factory helpers and their predicates.
from jaxlib import xla_client

hlo = xla_client.HloSharding

# Each special sharding kind has its own static factory.
rep, man, unk = hlo.replicate(), hlo.manual(), hlo.unknown()

# Every factory's result satisfies the matching predicate.
print(f"Replicated: {rep.is_replicated()}")
print(f"Manual: {man.is_manual()}")
print(f"Unknown: {unk.is_unknown()}")

# Tiled shardings can also be parsed from their textual form:
# a 2x1 device grid over devices 0 and 1.
sharding_str = "{devices=[2,1]0,1}"
parsed = hlo.from_string(sharding_str)
print(f"Devices in sharding: {parsed.num_devices()}")
print(f"Is tiled: {parsed.is_tiled()}")

Partition Specifications

# Example: building PartitionSpec objects and using the partition helpers.
from jaxlib import xla_client

P = xla_client.PartitionSpec

# One entry per array dimension: a mesh-axis name, or None for
# "leave this dimension unpartitioned".
spec1 = P('data')            # partition along the 'data' axis
spec2 = P('batch', 'model')  # partition two dims along two axes
spec3 = P(None, 'data')      # first dim unpartitioned

# The sentinel for unconstrained partitioning.
unconstrained = xla_client.UNCONSTRAINED_PARTITION
print(f"Unconstrained: {unconstrained}")

# Normalize a raw tuple into canonical partition form.
canonical = xla_client.canonicalize_partition(('data', None))
print(f"Canonical partition: {canonical}")

Install with Tessl CLI

npx tessl i tessl/pypi-jaxlib

docs

array-operations.md

compilation-execution.md

custom-operations.md

device-management.md

hardware-operations.md

index.md

plugin-system.md

sharding.md

xla-client.md

tile.json