CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-onnx

Open Neural Network Exchange for AI model interoperability and machine learning frameworks

Pending
Overview
Eval results
Files

docs/shape-inference.md

Shape Inference

Automatic shape and type inference for ONNX model graphs, enabling optimization and validation of tensor shapes throughout the computation graph. Shape inference is essential for model optimization, validation, and runtime preparation.

Capabilities

Model Shape Inference

Infer shapes and types for all values in a model's computation graph.

def infer_shapes(
    model: ModelProto | bytes,
    check_type: bool = False,
    strict_mode: bool = False,
    data_prop: bool = False,
) -> ModelProto:
    """
    Apply shape inference to the provided ModelProto.

    Inferred shapes are added to the value_info field of the graph.

    If the inferred values conflict with values already provided in the
    graph, that means that the provided values are invalid (or there is a
    bug in shape inference), and the result is unspecified.

    Parameters:
    - model: ModelProto, or the serialized bytes of a model, to run
        shape inference on
    - check_type: check type-equality for inputs and outputs
    - strict_mode: stricter shape inference; raise an error on any
        inference failure instead of silently stopping at the first error
    - data_prop: enable data propagation for a limited set of operators
        to perform shape computation

    Returns:
    ModelProto: model with inferred shape information
    """

def infer_shapes_path(
    model_path: str | os.PathLike,
    output_path: str | os.PathLike = "",
    check_type: bool = False,
    strict_mode: bool = False,
    data_prop: bool = False,
) -> None:
    """
    Run shape inference on a model file, same semantics as infer_shapes;
    unlike infer_shapes, this supports models larger than 2GB.

    The inferred model is written directly to output_path; when
    output_path is empty (the default), the original model file is
    overwritten in place.
    """

Node-Level Shape Inference

Perform shape inference for individual nodes and functions.

def infer_node_outputs(
    schema: onnx.defs.OpSchema,
    node: onnx.NodeProto,
    input_types: dict[str, onnx.TypeProto],
    input_data: dict[str, onnx.TensorProto] | None = None,
    input_sparse_data: dict[str, onnx.SparseTensorProto] | None = None,
    opset_imports: list[onnx.OperatorSetIdProto] | None = None,
    ir_version: int = onnx.IR_VERSION,
) -> dict[str, onnx.TypeProto]:
    """
    Infer output types for a single node.

    Parameters:
    - schema: OpSchema for the node's operator
    - node: NodeProto to infer outputs for
    - input_types: dict mapping each node input name to its TypeProto
    - input_data: optional dense input values, for data-dependent inference
    - input_sparse_data: optional sparse input values
    - opset_imports: optional operator-set imports to infer under
    - ir_version: IR version to use for inference

    Returns:
    dict[str, onnx.TypeProto]: inferred output types, keyed by output name

    Raises:
    InferenceError: if inference fails for the node
    """

def infer_function_output_types(
    function: FunctionProto,
    input_types: Sequence[TypeProto],
    attributes: Sequence[AttributeProto],
) -> list[TypeProto]:
    """
    Apply type-and-shape inference to the given function body, using the
    given input types and the given input attribute values.

    Returns the inferred types of the function's outputs, in order.
    """

Shape Inference Exceptions

Exception types for shape inference errors.

class InferenceError(Exception):
    """
    Exception raised when shape inference fails.

    Contains detailed information about why inference failed,
    including the specific node or operation that caused the error.
    """

Usage Examples

Basic Shape Inference

import onnx
from onnx import shape_inference

# Read a model whose intermediate values lack shape annotations.
original = onnx.load_model("model_without_shapes.onnx")

try:
    # Run inference, then persist the annotated model to disk.
    with_shapes = shape_inference.infer_shapes(original)
    onnx.save_model(with_shapes, "model_with_shapes.onnx")
    print("Shape inference completed successfully!")
except shape_inference.InferenceError as e:
    print(f"Shape inference failed: {e}")

Advanced Shape Inference Options

import onnx
from onnx import shape_inference

model = onnx.load_model("complex_model.onnx")

try:
    # Run inference with every optional feature switched on.
    inferred_model = shape_inference.infer_shapes(
        model,
        check_type=True,        # Check for type compatibility
        strict_mode=True,       # Apply strict inference rules
        data_prop=True          # Enable data value propagation
    )

    print("Advanced shape inference completed!")

    # Walk the inferred value_info entries and report each tensor shape.
    for value_info in inferred_model.graph.value_info:
        print(f"Value: {value_info.name}")
        if not value_info.type.HasField('tensor_type'):
            continue
        tensor_type = value_info.type.tensor_type
        dims = []
        for dim in tensor_type.shape.dim:
            # A dimension is either a concrete value or a symbolic name.
            dims.append(dim.dim_value if dim.HasField('dim_value') else dim.dim_param)
        print(f"  Shape: {dims}")
        print(f"  Type: {tensor_type.elem_type}")

except shape_inference.InferenceError as e:
    print(f"Shape inference failed: {e}")

Shape Inference for Large Models

import onnx
from onnx import shape_inference

# Models beyond the 2GB protobuf limit must be inferred via file paths.
try:
    shape_inference.infer_shapes_path(
        model_path="large_model.onnx",
        output_path="large_model_with_shapes.onnx",
        check_type=True,
    )
    print("Shape inference for large model completed!")
except shape_inference.InferenceError as e:
    print(f"Shape inference failed: {e}")

Node-Level Shape Inference

import onnx
from onnx import helper, shape_inference, defs, TensorProto

# Create input type information
input_type = helper.make_tensor_type_proto(TensorProto.FLOAT, [1, 3, 224, 224])

# Create a convolution node
conv_node = helper.make_node(
    'Conv',
    inputs=['input', 'weight'],
    outputs=['output'],
    kernel_shape=[3, 3],
    pads=[1, 1, 1, 1],
    strides=[1, 1]
)

# Get the operator schema
conv_schema = defs.get_schema('Conv')

# Create weight type information
weight_type = helper.make_tensor_type_proto(TensorProto.FLOAT, [64, 3, 3, 3])

try:
    # infer_node_outputs expects input_types as a dict mapping each node
    # input NAME to its TypeProto — not a positional list.
    output_types = shape_inference.infer_node_outputs(
        conv_schema,
        conv_node,
        {'input': input_type, 'weight': weight_type}
    )

    print(f"Inferred output shape for Conv node:")
    # The result is a dict keyed by output name, so iterate its items.
    for name, output_type in output_types.items():
        if output_type.HasField('tensor_type'):
            shape = [dim.dim_value or dim.dim_param
                    for dim in output_type.tensor_type.shape.dim]
            print(f"  Output {name}: {shape}")

except shape_inference.InferenceError as e:
    print(f"Node-level inference failed: {e}")

Debugging Shape Inference Issues

import onnx
from onnx import shape_inference

def debug_shape_inference(model_path):
    """Debug shape inference issues by examining the model structure."""

    model = onnx.load_model(model_path)

    # Report declared graph inputs; missing shapes here are a common
    # reason inference cannot proceed.
    print("Input information:")
    for graph_input in model.graph.input:
        print(f"  {graph_input.name}: {graph_input.type}")

    # Dump every node so problematic operators are easy to spot.
    print("\nNodes in graph:")
    for idx, node in enumerate(model.graph.node):
        print(f"  {idx}: {node.op_type} ({node.name or 'unnamed'})")
        print(f"    Inputs: {list(node.input)}")
        print(f"    Outputs: {list(node.output)}")

    try:
        inferred = shape_inference.infer_shapes(model, check_type=True)
    except shape_inference.InferenceError as e:
        print(f"\nShape inference failed: {e}")
        print("Common causes:")
        print("- Missing input shape information")
        print("- Unsupported operators")
        print("- Type mismatches between connected nodes")
        print("- Missing initializer tensors")
    else:
        print("\nShape inference successful!")
        print("Inferred value information:")
        for value_info in inferred.graph.value_info:
            print(f"  {value_info.name}: {value_info.type}")

# Debug a problematic model
# debug_shape_inference("problematic_model.onnx")

Integrating Shape Inference with Model Construction

import onnx
from onnx import helper, shape_inference, TensorProto
import numpy as np

def create_model_with_inference():
    """Create a two-layer MLP model and automatically infer its shapes.

    Returns the model with inferred intermediate shapes, or the
    un-inferred model if shape inference fails.
    """

    # Define input/output (batch dimension left symbolic as None)
    X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [None, 784])
    Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [None, 10])

    # Create weight matrices
    W1 = np.random.randn(784, 128).astype(np.float32)
    W2 = np.random.randn(128, 10).astype(np.float32)

    # make_tensor expects a flat sequence of values; pass the arrays
    # flattened explicitly (older onnx releases do not accept 2-D arrays).
    W1_tensor = helper.make_tensor('W1', TensorProto.FLOAT, W1.shape, W1.flatten())
    W2_tensor = helper.make_tensor('W2', TensorProto.FLOAT, W2.shape, W2.flatten())

    # Create computation nodes: X @ W1 -> ReLU -> @ W2 -> Y
    matmul1 = helper.make_node('MatMul', ['X', 'W1'], ['hidden'])
    relu = helper.make_node('Relu', ['hidden'], ['hidden_relu'])
    matmul2 = helper.make_node('MatMul', ['hidden_relu', 'W2'], ['Y'])

    # Create graph (weights supplied as initializers)
    graph = helper.make_graph(
        [matmul1, relu, matmul2],
        'mlp_model',
        [X], [Y],
        [W1_tensor, W2_tensor]
    )

    # Create model
    model = helper.make_model(graph)

    try:
        # Perform shape inference to fill in intermediate shapes
        inferred_model = shape_inference.infer_shapes(model)

        print("Model created with shape inference:")
        for value_info in inferred_model.graph.value_info:
            if value_info.type.HasField('tensor_type'):
                # '?' marks dimensions that remain symbolic/unknown.
                shape = [dim.dim_value if dim.HasField('dim_value') else '?'
                        for dim in value_info.type.tensor_type.shape.dim]
                print(f"  {value_info.name}: {shape}")

        return inferred_model

    except shape_inference.InferenceError as e:
        print(f"Shape inference failed: {e}")
        return model

# Create model with automatic shape inference
model = create_model_with_inference()

Install with Tessl CLI

npx tessl i tessl/pypi-onnx

docs

backend-integration.md

index.md

model-composition.md

model-construction.md

model-hub.md

model-io.md

model-validation.md

numpy-integration.md

operator-definitions.md

reference-implementation.md

shape-inference.md

text-processing.md

version-conversion.md

tile.json