Comprehensive utilities library for JAX testing, debugging, and instrumentation
A utility module for managing and retrieving array dimension templates in JAX-based machine learning workflows.
@generates
class ShapeTemplateManager:
    """Manages dimension templates for array shapes."""

    def __init__(self, **dimensions: int):
        """Initialize the manager with named dimensions.

        Args:
            **dimensions: Keyword arguments mapping dimension names to sizes.
        """
        pass

    def __getitem__(self, key: str):
        """Retrieve dimensions using a string key.

        Args:
            key: String of dimension names. Each character (or parenthesized
                group) represents one dimension. Parentheses group dimensions
                whose sizes are multiplied together. '*' represents a
                wildcard (None) dimension.

        Returns:
            A single value (an int, or None for '*') if key is a single
            character, otherwise a tuple of dimensions. Each parenthesized
            group is flattened into the product of its sizes, and each '*'
            becomes None.
        """
        pass

    def __setitem__(self, key: str, value: int):
        """Set a dimension by name.

        Args:
            key: Single-character dimension name.
            value: Dimension size.
        """
        pass

Provides dimension management utilities for JAX arrays.
@satisfied-by
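The stub above only specifies the interface. The following is a minimal pure-Python sketch of one way those docstrings could be satisfied; it is not the generated or library implementation. Several behaviours the spec leaves open are assumptions made here: names are single characters other than `*`, `(` and `)`; sizes must be non-negative ints; unknown names raise `KeyError`; nested groups or `*` inside a group raise `ValueError`; and a single-group key such as `"(BT)"` returns a one-element tuple, because it is not a single character. The helper `_lookup` is hypothetical.

```python
from typing import Dict, List, Optional, Tuple, Union

Dim = Optional[int]  # None stands for a wildcard ('*') dimension


class ShapeTemplateManager:
    """Sketch: maps single-character dimension names to sizes."""

    def __init__(self, **dimensions: int):
        self._dims: Dict[str, int] = {}
        for name, size in dimensions.items():
            self[name] = size  # reuse the validation in __setitem__

    def __setitem__(self, key: str, value: int) -> None:
        # Assumption: '*', '(' and ')' are reserved and cannot be names.
        if len(key) != 1 or key in "*()":
            raise KeyError(f"invalid dimension name: {key!r}")
        if isinstance(value, bool) or not isinstance(value, int) or value < 0:
            raise ValueError(f"size must be a non-negative int, got {value!r}")
        self._dims[key] = value

    def _lookup(self, name: str) -> int:
        # Hypothetical helper: resolve one name, with a clearer error message.
        try:
            return self._dims[name]
        except KeyError:
            raise KeyError(f"unknown dimension: {name!r}") from None

    def __getitem__(self, key: str) -> Union[Dim, Tuple[Dim, ...]]:
        shape: List[Dim] = []
        i = 0
        while i < len(key):
            ch = key[i]
            if ch == "*":
                shape.append(None)  # wildcard dimension
                i += 1
            elif ch == "(":
                end = key.find(")", i + 1)
                if end == -1:
                    raise ValueError(f"unbalanced '(' in {key!r}")
                group = key[i + 1:end]
                # Assumption: groups are non-empty, flat, and wildcard-free.
                if not group or "*" in group or "(" in group:
                    raise ValueError(f"invalid group {group!r} in {key!r}")
                size = 1
                for name in group:
                    size *= self._lookup(name)  # flatten: multiply sizes
                shape.append(size)
                i = end + 1
            elif ch == ")":
                raise ValueError(f"unbalanced ')' in {key!r}")
            else:
                shape.append(self._lookup(ch))
                i += 1
        # Single-character keys return a bare value; anything else a tuple.
        return shape[0] if len(key) == 1 else tuple(shape)


dims = ShapeTemplateManager(B=2, T=3, D=4)
print(dims["B"])      # 2
print(dims["BTD"])    # (2, 3, 4)
print(dims["B(TD)"])  # (2, 12)
print(dims["*D"])     # (None, 4)
```

Parsing the key in a single left-to-right pass keeps each parenthesized group self-contained, so the product of a group is computed at the moment its closing `)` is found. In a JAX workflow the returned tuple would typically be compared against an array's `.shape`, with `None` entries treated as "any size" by whatever assertion helper consumes them.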
Install with Tessl CLI
npx tessl i tessl/pypi-chex