
tessl/pypi-chex

Comprehensive utilities library for JAX testing, debugging, and instrumentation


Neural Network Layer Validator

A utility module for validating neural network layer parameters with support for optional components.

Background

Neural networks often have layers with both required and optional parameters. For example, a dense layer always has weights but may also carry a bias vector or normalization parameters. A validation function for such a layer must therefore check the required parameters unconditionally and check each optional parameter only when it is supplied.

Requirements

Implement a validation module that checks neural network layer parameters. The module should validate:

  1. Input data: Always validate that input tensors have the expected rank, with a leading batch dimension and a trailing feature (or channel) dimension
  2. Weights: Always validate that weight matrices and convolution kernels have dimensions compatible with the inputs
  3. Optional bias: Only validate bias vectors when they are provided (not None)
  4. Optional scale: Only validate scale parameters when they are provided (not None)

The validation should work correctly whether optional parameters are provided or omitted.
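The core of requirements 3 and 4 is a guard that skips a check when the argument is None. The sketch below shows that pattern with plain NumPy rather than chex, so it runs without JAX; the helper names `check_shape` and `check_shape_if_present` are hypothetical (chex offers `chex.if_args_not_none` for the same purpose). The guard matters because `np.shape(None)` is `()`, so an unguarded shape check would reject an omitted bias instead of skipping it.

```python
import numpy as np


def check_shape(name, array, expected):
    """Raise AssertionError if `array` does not have exactly the `expected` shape."""
    actual = tuple(np.shape(array))
    assert actual == tuple(expected), f"{name}: expected shape {tuple(expected)}, got {actual}"


def check_shape_if_present(name, array, expected):
    """Check an optional parameter only when it was provided (i.e. is not None)."""
    if array is not None:
        check_shape(name, array, expected)


# Required parameter: always checked.
check_shape("weights", np.zeros((128, 64)), (128, 64))

# Optional parameter omitted: the check is skipped instead of failing on None.
check_shape_if_present("bias", None, (64,))

# Optional parameter provided: the check runs and catches a wrong shape.
try:
    check_shape_if_present("scale", np.zeros((32,)), (64,))
except AssertionError as err:
    print(err)  # scale: expected shape (64,), got (32,)
```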

Implementation

@generates

API

def validate_dense_layer(inputs, weights, bias=None, scale=None):
    """
    Validate parameters for a dense neural network layer.

    Validates that:
    - inputs has shape (batch_size, input_features)
    - weights has shape (input_features, output_features)
    - bias (if provided) has shape (output_features,)
    - scale (if provided) has shape (output_features,)

    Parameters:
    - inputs: Input tensor with shape (batch_size, input_features)
    - weights: Weight matrix with shape (input_features, output_features)
    - bias: Optional bias vector with shape (output_features,), defaults to None
    - scale: Optional scale vector with shape (output_features,), defaults to None

    Raises:
    - AssertionError: If any validation check fails
    """
    pass

def validate_conv_layer(inputs, kernel, bias=None):
    """
    Validate parameters for a convolutional layer.

    Validates that:
    - inputs has 4 dimensions (batch, height, width, channels)
    - kernel has 4 dimensions (kernel_height, kernel_width, in_channels, out_channels)
    - bias (if provided) has shape (out_channels,)
    - inputs and kernel have matching channel dimensions

    Parameters:
    - inputs: Input tensor with shape (batch, height, width, in_channels)
    - kernel: Convolution kernel with shape (kh, kw, in_channels, out_channels)
    - bias: Optional bias vector with shape (out_channels,), defaults to None

    Raises:
    - AssertionError: If any validation check fails
    """
    pass
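One way the two stubs above could be filled in is sketched below with plain NumPy shape checks, so it runs without JAX. It is an illustration under that assumption, not the generated implementation: the task lists chex as the dependency, where `chex.assert_rank`, `chex.assert_shape` and `chex.if_args_not_none` would play the same roles, and the helper `_check` is hypothetical.

```python
import numpy as np


def _check(condition, message):
    # Raise AssertionError explicitly so the check survives `python -O`.
    if not condition:
        raise AssertionError(message)


def validate_dense_layer(inputs, weights, bias=None, scale=None):
    """Sketch: required arguments are always checked, optional ones only when not None."""
    _check(np.ndim(inputs) == 2, f"inputs must be rank 2, got shape {np.shape(inputs)}")
    _check(np.ndim(weights) == 2, f"weights must be rank 2, got shape {np.shape(weights)}")
    _, input_features = np.shape(inputs)
    weight_in, output_features = np.shape(weights)
    _check(weight_in == input_features,
           f"weights expect {weight_in} input features, inputs have {input_features}")
    for name, param in (("bias", bias), ("scale", scale)):
        if param is not None:  # optional: skipped entirely when omitted
            _check(np.shape(param) == (output_features,),
                   f"{name} must have shape ({output_features},), got {np.shape(param)}")


def validate_conv_layer(inputs, kernel, bias=None):
    """Sketch: NHWC inputs, (kh, kw, in, out) kernel, optional per-output-channel bias."""
    _check(np.ndim(inputs) == 4, f"inputs must be rank 4, got shape {np.shape(inputs)}")
    _check(np.ndim(kernel) == 4, f"kernel must be rank 4, got shape {np.shape(kernel)}")
    in_channels = np.shape(inputs)[-1]
    _, _, kernel_in, out_channels = np.shape(kernel)
    _check(kernel_in == in_channels,
           f"kernel expects {kernel_in} input channels, inputs have {in_channels}")
    if bias is not None:  # optional: skipped entirely when omitted
        _check(np.shape(bias) == (out_channels,),
               f"bias must have shape ({out_channels},), got {np.shape(bias)}")
```

Raising AssertionError explicitly, rather than using a bare `assert`, keeps the documented `Raises: AssertionError` contract even when Python runs with optimizations enabled.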

Test Cases

  • Validating a dense layer with inputs (32, 128), weights (128, 64), and no optional parameters succeeds @test
  • Validating a dense layer with inputs (32, 128), weights (128, 64), bias (64,), and scale (64,) succeeds @test
  • Validating a dense layer with inputs (32, 128), weights (128, 64), and only bias (64,) succeeds @test
  • Validating a dense layer with inputs (32, 128), weights (128, 64), bias (64,), and scale (32,) fails due to incorrect scale shape @test
  • Validating a conv layer with inputs (16, 28, 28, 3), kernel (3, 3, 3, 32), and no bias succeeds @test
  • Validating a conv layer with inputs (16, 28, 28, 3), kernel (3, 3, 3, 32), and bias (32,) succeeds @test
  • Validating a conv layer with inputs (16, 28, 28, 3), kernel (3, 3, 5, 32), and no bias fails due to mismatched input channels @test

Dependencies { .dependencies }

chex { .dependency }

Testing and validation library for JAX code providing shape assertions and conditional validation utilities.

@satisfied-by

Install with Tessl CLI

npx tessl i tessl/pypi-chex
