ONNX (Open Neural Network Exchange): an open standard for AI model interoperability across machine learning frameworks
—
Automatic shape and type inference for ONNX model graphs, enabling optimization and validation of tensor shapes throughout the computation graph. Shape inference is essential for model optimization, validation, and runtime preparation.
Infer shapes and types for all values in a model's computation graph.
def infer_shapes(
    model: ModelProto | bytes,
    check_type: bool = False,
    strict_mode: bool = False,
    data_prop: bool = False,
) -> ModelProto:
    """
    Apply shape inference to the provided ModelProto.

    Inferred shapes are added to the value_info field of the graph.
    If the inferred values conflict with values already provided in the
    graph, that means that the provided values are invalid (or there is a
    bug in shape inference), and the result is unspecified.

    Parameters:
    - model: ModelProto, or the serialized bytes of a model, to run inference on
    - check_type: Checks the type-equality for input and output
    - strict_mode: Stricter shape inference; when True, errors are raised,
      otherwise inference simply stops at the first error
    - data_prop: Enables data propagation for limited operators to perform shape computation

    Returns:
    ModelProto: model with inferred shape information
    """
def infer_shapes_path(
    model_path: str | os.PathLike,
    output_path: str | os.PathLike = "",
    check_type: bool = False,
    strict_mode: bool = False,
    data_prop: bool = False,
) -> None:
    """
    Run shape inference on the model at *model_path* (same semantics as
    infer_shapes); this path-based variant supports models larger than 2GB.

    The inferred model is written directly to *output_path*; the default
    ("") writes back to the original model path.

    Parameters:
    - model_path: path of the model to run shape inference on
    - output_path: destination path for the inferred model; defaults to model_path
    - check_type: Checks the type-equality for input and output
    - strict_mode: Stricter shape inference; when True, errors are raised,
      otherwise inference simply stops at the first error
    - data_prop: Enables data propagation for limited operators to perform shape computation
    """
# Perform shape inference for individual nodes and functions.
def infer_node_outputs(
    schema: onnx.defs.OpSchema,
    node: onnx.NodeProto,
    input_types: dict[str, onnx.TypeProto],
    input_data: dict[str, onnx.TensorProto] | None = None,
    input_sparse_data: dict[str, onnx.SparseTensorProto] | None = None,
    opset_imports: list[onnx.OperatorSetIdProto] | None = None,
    ir_version: int = onnx.IR_VERSION,
) -> dict[str, onnx.TypeProto]:
    """
    Infer output types for a single node.

    Parameters:
    - schema: OpSchema for the node's operator
    - node: NodeProto to infer outputs for
    - input_types: dict mapping input names to TypeProto for node inputs
    - input_data: Optional input tensor data (keyed by input name) for
      data-dependent inference
    - input_sparse_data: Optional sparse input data (keyed by input name)
    - opset_imports: Optional opset imports
      # NOTE(review): behavior when omitted is not shown here — presumably a
      # default opset is used; confirm against the onnx documentation
    - ir_version: IR version to use (defaults to onnx.IR_VERSION)

    Returns:
    dict[str, onnx.TypeProto]: Inferred output types, keyed by output name

    Raises:
    InferenceError: If inference fails for the node
    """
def infer_function_output_types(
    function: FunctionProto,
    input_types: Sequence[TypeProto],
    attributes: Sequence[AttributeProto],
) -> list[TypeProto]:
    """
    Apply type-and-shape-inference to the given function body, with the given
    input types and the given input attribute values.

    Returns:
    list[TypeProto]: inferred output types of the function
    """
# Exception types for shape inference errors.
class InferenceError(Exception):
    """
    Exception raised when shape inference fails.

    Contains detailed information about why inference failed,
    including the specific node or operation that caused the error.
    """


# Example: basic shape inference on a loaded model.
import onnx
from onnx import shape_inference

# Load a model without shape information
model = onnx.load_model("model_without_shapes.onnx")
try:
    # Perform shape inference
    inferred_model = shape_inference.infer_shapes(model)
    # Save model with inferred shapes
    onnx.save_model(inferred_model, "model_with_shapes.onnx")
    print("Shape inference completed successfully!")
except shape_inference.InferenceError as e:
    print(f"Shape inference failed: {e}")
import onnx
from onnx import shape_inference

model = onnx.load_model("complex_model.onnx")
try:
    # Perform shape inference with type checking and data propagation
    inferred_model = shape_inference.infer_shapes(
        model,
        check_type=True,   # Check for type compatibility
        strict_mode=True,  # Apply strict inference rules
        data_prop=True     # Enable data value propagation
    )
    print("Advanced shape inference completed!")
    # Check the inferred shapes
    for value_info in inferred_model.graph.value_info:
        print(f"Value: {value_info.name}")
        if value_info.type.HasField('tensor_type'):
            tensor_type = value_info.type.tensor_type
            # Symbolic dimensions have dim_param instead of dim_value
            shape = [dim.dim_value if dim.HasField('dim_value') else dim.dim_param
                     for dim in tensor_type.shape.dim]
            print(f" Shape: {shape}")
            print(f" Type: {tensor_type.elem_type}")
except shape_inference.InferenceError as e:
    print(f"Shape inference failed: {e}")
import onnx
from onnx import shape_inference

# For models larger than 2GB, use the path-based inference
try:
    shape_inference.infer_shapes_path(
        model_path="large_model.onnx",
        output_path="large_model_with_shapes.onnx",
        check_type=True
    )
    print("Shape inference for large model completed!")
except shape_inference.InferenceError as e:
    print(f"Shape inference failed: {e}")
import onnx
from onnx import helper, shape_inference, defs, TensorProto

# Create input type information
input_type = helper.make_tensor_type_proto(TensorProto.FLOAT, [1, 3, 224, 224])
# Create a convolution node
conv_node = helper.make_node(
    'Conv',
    inputs=['input', 'weight'],
    outputs=['output'],
    kernel_shape=[3, 3],
    pads=[1, 1, 1, 1],
    strides=[1, 1]
)
# Get the operator schema
conv_schema = defs.get_schema('Conv')
# Create weight type information
weight_type = helper.make_tensor_type_proto(TensorProto.FLOAT, [64, 3, 3, 3])
try:
    # Infer output types for the node.
    # infer_node_outputs expects input_types as a dict mapping input names
    # to TypeProto, and returns a dict keyed by output name.
    output_types = shape_inference.infer_node_outputs(
        conv_schema,
        conv_node,
        {'input': input_type, 'weight': weight_type}
    )
    print("Inferred output shape for Conv node:")
    for name, output_type in output_types.items():
        if output_type.HasField('tensor_type'):
            # Use HasField: `dim_value or dim_param` would misreport a
            # legitimate dim_value of 0 as the (empty) dim_param.
            shape = [dim.dim_value if dim.HasField('dim_value') else dim.dim_param
                     for dim in output_type.tensor_type.shape.dim]
            print(f" Output {name}: {shape}")
except shape_inference.InferenceError as e:
    print(f"Node-level inference failed: {e}")
import onnx
from onnx import shape_inference


def debug_shape_inference(model_path):
    """Debug shape inference issues by examining the model structure."""
    model = onnx.load_model(model_path)
    # Check if model has input shapes defined
    print("Input information:")
    for input_info in model.graph.input:
        print(f" {input_info.name}: {input_info.type}")
    # Check for nodes that might cause issues
    print("\nNodes in graph:")
    for i, node in enumerate(model.graph.node):
        print(f" {i}: {node.op_type} ({node.name or 'unnamed'})")
        print(f" Inputs: {list(node.input)}")
        print(f" Outputs: {list(node.output)}")
    try:
        # Attempt shape inference
        inferred_model = shape_inference.infer_shapes(model, check_type=True)
        print("\nShape inference successful!")
        # Show inferred shapes
        print("Inferred value information:")
        for value_info in inferred_model.graph.value_info:
            print(f" {value_info.name}: {value_info.type}")
    except shape_inference.InferenceError as e:
        print(f"\nShape inference failed: {e}")
        print("Common causes:")
        print("- Missing input shape information")
        print("- Unsupported operators")
        print("- Type mismatches between connected nodes")
        print("- Missing initializer tensors")


# Debug a problematic model
# debug_shape_inference("problematic_model.onnx")
import onnx
from onnx import helper, shape_inference, TensorProto
import numpy as np


def create_model_with_inference():
    """Create a model and automatically infer shapes."""
    # Define input (without complete shape information)
    X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [None, 784])
    Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [None, 10])
    # Create weight matrices
    W1 = np.random.randn(784, 128).astype(np.float32)
    W2 = np.random.randn(128, 10).astype(np.float32)
    # make_tensor expects a flat sequence of values, so flatten the arrays
    W1_tensor = helper.make_tensor('W1', TensorProto.FLOAT, W1.shape, W1.flatten())
    W2_tensor = helper.make_tensor('W2', TensorProto.FLOAT, W2.shape, W2.flatten())
    # Create computation nodes
    matmul1 = helper.make_node('MatMul', ['X', 'W1'], ['hidden'])
    relu = helper.make_node('Relu', ['hidden'], ['hidden_relu'])
    matmul2 = helper.make_node('MatMul', ['hidden_relu', 'W2'], ['Y'])
    # Create graph
    graph = helper.make_graph(
        [matmul1, relu, matmul2],
        'mlp_model',
        [X], [Y],
        [W1_tensor, W2_tensor]
    )
    # Create model
    model = helper.make_model(graph)
    try:
        # Perform shape inference to fill in intermediate shapes
        inferred_model = shape_inference.infer_shapes(model)
        print("Model created with shape inference:")
        for value_info in inferred_model.graph.value_info:
            if value_info.type.HasField('tensor_type'):
                shape = [dim.dim_value if dim.HasField('dim_value') else '?'
                         for dim in value_info.type.tensor_type.shape.dim]
                print(f" {value_info.name}: {shape}")
        return inferred_model
    except shape_inference.InferenceError as e:
        print(f"Shape inference failed: {e}")
        return model


# Create model with automatic shape inference
model = create_model_with_inference()
# Install with Tessl CLI
npx tessl i tessl/pypi-onnx