
Workspace: tessl
Visibility: Public
Describes: pkg:pypi/chex@0.1.x

tessl/pypi-chex

tessl install tessl/pypi-chex@0.1.0

Comprehensive utilities library for JAX testing, debugging, and instrumentation

Agent success: 73% (agent success rate when using this tile)
Improvement: 1.92x (agent success rate with this tile relative to the baseline)
Baseline: 38% (agent success rate without this tile)

evals/scenario-3/task.md

Array Shape Validator

A utility for validating JAX array shapes and dimensions in machine learning pipelines.

Capabilities

Validate a single array's shape

  • Given a JAX array of shape (10, 20), calling validate_shape with expected_shape=(10, 20) completes without error @test
  • Given a JAX array of shape (5, 5, 5), calling validate_shape with expected_shape=(5, 5, 5) completes without error @test
  • Given a JAX array of shape (10, 20), calling validate_shape with expected_shape=(20, 10) raises an AssertionError @test

Validate array dimensionality

  • Given a JAX array of shape (10, 20), calling validate_rank with expected_rank=2 completes without error @test
  • Given a JAX array of shape (5, 5, 5), calling validate_rank with expected_rank=3 completes without error @test
  • Given a JAX array of shape (10, 20), calling validate_rank with expected_rank=3 raises an AssertionError @test

Validate matching shapes

  • Given two JAX arrays both of shape (10, 20), calling validate_equal_shapes with both arrays completes without error @test
  • Given three JAX arrays all of shape (5, 5), calling validate_equal_shapes with all three arrays completes without error @test
  • Given two JAX arrays with shapes (10, 20) and (10, 30), calling validate_equal_shapes raises an AssertionError @test

Validate flexible shape patterns

  • Given a JAX array of shape (8, 10, 20), calling validate_flexible_shape with pattern=(None, 10, 20) completes without error @test
  • Given a JAX array of shape (16, 10, 20), calling validate_flexible_shape with pattern=(None, 10, 20) completes without error @test

Implementation

@generates

API

```python
"""Array shape validation utilities for JAX arrays."""

def validate_shape(array, expected_shape):
    """
    Validates that the given array matches the expected shape.

    Args:
        array: A JAX array to validate
        expected_shape: A tuple specifying the expected shape (e.g., (10, 20))

    Raises:
        AssertionError: If the array shape doesn't match expected_shape
    """
    pass


def validate_rank(array, expected_rank):
    """
    Validates that the given array has the expected number of dimensions.

    Args:
        array: A JAX array to validate
        expected_rank: An integer specifying the expected number of dimensions

    Raises:
        AssertionError: If the array rank doesn't match expected_rank
    """
    pass


def validate_equal_shapes(*arrays):
    """
    Validates that all provided arrays have matching shapes.

    Args:
        *arrays: Variable number of JAX arrays to compare

    Raises:
        AssertionError: If any arrays have different shapes
    """
    pass


def validate_flexible_shape(array, pattern):
    """
    Validates that an array matches a flexible shape pattern.

    Allows None as a wildcard that matches any dimension size.

    Args:
        array: A JAX array to validate
        pattern: A tuple with integer dimensions and None wildcards (e.g., (None, 10, 20))

    Raises:
        AssertionError: If the array shape doesn't match the pattern
    """
    pass
```

Dependencies { .dependencies }

chex { .dependency }

Provides shape and array assertion utilities for JAX.

@satisfied-by