Comprehensive utilities library for JAX testing, debugging, and instrumentation
A utility module for validating neural network layer parameters with support for optional components.
Neural networks often have layers with both required and optional parameters. For example, a dense layer always has weights but may optionally have a bias vector or a normalization scale. A validation function for such a layer must always check the required parameters. It should check each optional parameter only when it is supplied, and treat None as "omitted" rather than as an error.
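Implement a validation module that checks neural network layer parameters. The module should validate:
- Dense layers: the required inputs and weights, plus the optional bias and scale vectors.
- Convolutional layers: the required inputs and kernel, plus an optional bias vector, including a check that the input and kernel channel dimensions match.
The validation should work correctly whether the optional parameters are provided or omitted.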
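The required-versus-optional rule can be sketched with a single hypothetical helper, `check_shape`. It always checks a required array but skips an optional one that was left as `None`. This is a dependency-free sketch that assumes arrays expose a `.shape` tuple, as `jax.numpy` and NumPy arrays do. The `SimpleNamespace` objects in the usage lines only stand in for arrays. The helper raises `AssertionError` explicitly instead of using an `assert` statement, so the checks still run when Python is started with `-O`.

```python
from types import SimpleNamespace


def check_shape(x, expected, name, optional=False):
    """Hypothetical helper: raise AssertionError unless x.shape == expected.

    An optional argument that is None counts as "omitted" and is skipped;
    a required argument that is None is an error.
    """
    if x is None:
        if not optional:
            raise AssertionError(f"{name} is required but was None")
        return  # optional parameter omitted: nothing to validate
    actual, expected = tuple(x.shape), tuple(expected)
    if actual != expected:
        raise AssertionError(f"{name}: expected shape {expected}, got {actual}")


# Usage with stand-in objects that only carry a .shape attribute.
weights = SimpleNamespace(shape=(16, 4))
bias = None  # the caller omitted the optional bias
check_shape(weights, (16, 4), "weights")        # required: always checked
check_shape(bias, (4,), "bias", optional=True)  # optional: skipped when None
```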
```python
def validate_dense_layer(inputs, weights, bias=None, scale=None):
    """
    Validate parameters for a dense neural network layer.

    Validates that:
    - inputs has shape (batch_size, input_features)
    - weights has shape (input_features, output_features)
    - bias (if provided) has shape (output_features,)
    - scale (if provided) has shape (output_features,)

    Parameters:
    - inputs: Input tensor with shape (batch_size, input_features)
    - weights: Weight matrix with shape (input_features, output_features)
    - bias: Optional bias vector with shape (output_features,), defaults to None
    - scale: Optional scale vector with shape (output_features,), defaults to None

    Raises:
    - AssertionError: If any validation check fails
    """
    pass


def validate_conv_layer(inputs, kernel, bias=None):
    """
    Validate parameters for a convolutional layer.

    Validates that:
    - inputs has 4 dimensions (batch, height, width, channels)
    - kernel has 4 dimensions (kernel_height, kernel_width, in_channels, out_channels)
    - bias (if provided) has shape (out_channels,)
    - inputs and kernel have matching channel dimensions

    Parameters:
    - inputs: Input tensor with shape (batch, height, width, in_channels)
    - kernel: Convolution kernel with shape (kh, kw, in_channels, out_channels)
    - bias: Optional bias vector with shape (out_channels,), defaults to None

    Raises:
    - AssertionError: If any validation check fails
    """
    pass
```

Dependency: chex, a testing and validation library for JAX code that provides shape assertions and conditional validation utilities.
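The two stubs above could be filled in roughly as follows. This is a minimal, dependency-free sketch, not the reference solution. It assumes arrays expose a `.shape` tuple, which `jax.numpy` and NumPy arrays do. The `_check` helper and the `SimpleNamespace` stand-ins are illustrative only. In a JAX codebase the same checks are often written with chex assertions such as `chex.assert_rank` and `chex.assert_shape`.

```python
from types import SimpleNamespace


def _check(condition, message):
    # Raise explicitly rather than using `assert`, so checks survive `python -O`.
    if not condition:
        raise AssertionError(message)


def validate_dense_layer(inputs, weights, bias=None, scale=None):
    """Validate dense-layer parameters; optional arguments left as None are skipped."""
    in_shape, w_shape = tuple(inputs.shape), tuple(weights.shape)
    _check(len(in_shape) == 2, f"inputs must be 2-D (batch_size, input_features), got {in_shape}")
    _check(len(w_shape) == 2, f"weights must be 2-D (input_features, output_features), got {w_shape}")
    _check(in_shape[1] == w_shape[0],
           f"inputs has {in_shape[1]} features but weights expects {w_shape[0]}")
    out_features = w_shape[1]
    # Optional vectors are validated only when the caller actually passed them.
    for name, vec in (("bias", bias), ("scale", scale)):
        if vec is not None:
            _check(tuple(vec.shape) == (out_features,),
                   f"{name} must have shape ({out_features},), got {tuple(vec.shape)}")


def validate_conv_layer(inputs, kernel, bias=None):
    """Validate NHWC conv-layer parameters; an omitted bias (None) is skipped."""
    in_shape, k_shape = tuple(inputs.shape), tuple(kernel.shape)
    _check(len(in_shape) == 4, f"inputs must be 4-D (batch, height, width, channels), got {in_shape}")
    _check(len(k_shape) == 4, f"kernel must be 4-D (kh, kw, in_channels, out_channels), got {k_shape}")
    _check(in_shape[3] == k_shape[2],
           f"inputs has {in_shape[3]} channels but kernel expects {k_shape[2]}")
    if bias is not None:
        out_channels = k_shape[3]
        _check(tuple(bias.shape) == (out_channels,),
               f"bias must have shape ({out_channels},), got {tuple(bias.shape)}")


# Usage with stand-in objects that only carry a .shape attribute.
x = SimpleNamespace(shape=(8, 16))
w = SimpleNamespace(shape=(16, 4))
validate_dense_layer(x, w)                                    # optional params omitted
validate_dense_layer(x, w, bias=SimpleNamespace(shape=(4,)))  # bias checked against (4,)
validate_conv_layer(SimpleNamespace(shape=(2, 28, 28, 3)), SimpleNamespace(shape=(3, 3, 3, 8)))
```

Checking the ranks before indexing into the shapes keeps the error messages meaningful: a 3-D input fails the rank check instead of raising an `IndexError` on `in_shape[3]`.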
Install with the Tessl CLI:
npx tessl i tessl/pypi-chexevals