Open Neural Network Exchange for AI model interoperability and machine learning frameworks
—
Comprehensive validation functions to verify ONNX model correctness, including graph structure, node compatibility, type consistency, and operator definitions. These functions ensure models conform to the ONNX specification and can be executed correctly by runtimes.
Validate complete ONNX models with comprehensive checks for structure, types, and operator compatibility.
def check_model(
    model: ModelProto | str | bytes | os.PathLike,
    full_check: bool = False,
    skip_opset_compatibility_check: bool = False,
) -> None:
    """
    Check the consistency of a model. An exception is raised if the test fails.

    Parameters:
    - model: Model to check. Per the signature this may be a ModelProto, a
      str, bytes, or an os.PathLike (presumably a file path or serialized
      model bytes in the non-ModelProto cases; confirm against the
      implementation).
    - full_check: If True, the function also checks for shapes that can be inferred.
    - skip_opset_compatibility_check: If True, the function skips the check for
      opset compatibility.

    Returns:
    None. Failure is signalled by raising, not by a return value.
    """

Validate computation graphs for structural correctness and data flow consistency.
def check_graph(graph, ctx=DEFAULT_CONTEXT):
    """
    Validate a GraphProto.

    Parameters:
    - graph: GraphProto to validate
    - ctx: Validation context for configuration; defaults to the
      module-level DEFAULT_CONTEXT

    Returns:
    None (raises ValidationError if invalid)

    Raises:
    ValidationError: If graph structure is invalid
    """

Validate individual nodes for operator compatibility and attribute correctness.
def check_node(node, ctx=DEFAULT_CONTEXT):
    """
    Validate a NodeProto.

    Parameters:
    - node: NodeProto to validate
    - ctx: Validation context for configuration; defaults to the
      module-level DEFAULT_CONTEXT

    Returns:
    None (raises ValidationError if invalid)

    Raises:
    ValidationError: If node is invalid (unknown operator, wrong attributes, etc.)
    """

Validate user-defined functions for correctness and compatibility.
def check_function(function, ctx=DEFAULT_CONTEXT):
    """
    Validate a FunctionProto.

    Parameters:
    - function: FunctionProto to validate
    - ctx: Validation context for configuration; defaults to the
      module-level DEFAULT_CONTEXT

    Returns:
    None (raises ValidationError if invalid)

    Raises:
    ValidationError: If function definition is invalid
    """

Validate tensor data structures for type consistency and data integrity.
def check_tensor(tensor: TensorProto, ctx: C.CheckerContext = DEFAULT_CONTEXT) -> None:
    """
    Validate a TensorProto.

    Parameters:
    - tensor: TensorProto to validate
    - ctx: Validation context for configuration; defaults to the
      module-level DEFAULT_CONTEXT

    Returns:
    None (raises ValidationError if invalid)

    Raises:
    ValidationError: If tensor data is inconsistent or malformed
    """
def check_sparse_tensor(
    sparse: SparseTensorProto, ctx: C.CheckerContext = DEFAULT_CONTEXT
) -> None:
    """
    Validate a SparseTensorProto.

    Parameters:
    - sparse: SparseTensorProto to validate
    - ctx: Validation context for configuration; defaults to the
      module-level DEFAULT_CONTEXT

    Returns:
    None (raises ValidationError if invalid)

    Raises:
    ValidationError: If sparse tensor structure is invalid
    """

Validate node attributes for type correctness and operator compatibility.
def check_attribute(attribute, ctx=DEFAULT_CONTEXT):
    """
    Validate an AttributeProto.

    Parameters:
    - attribute: AttributeProto to validate
    - ctx: Validation context for configuration; defaults to the
      module-level DEFAULT_CONTEXT

    Returns:
    None (raises ValidationError if invalid)

    Raises:
    ValidationError: If attribute type or value is invalid
    """
def check_value_info(value_info, ctx=DEFAULT_CONTEXT):
    """
    Validate a ValueInfoProto.

    Parameters:
    - value_info: ValueInfoProto to validate
    - ctx: Validation context for configuration; defaults to the
      module-level DEFAULT_CONTEXT

    Returns:
    None (raises ValidationError if invalid)

    Raises:
    ValidationError: If value info type or shape is invalid
    """

Important constants and context objects for validation configuration:
DEFAULT_CONTEXT # Default validation context with standard settings
MAXIMUM_PROTOBUF # Maximum protobuf file size limit (2GB)
class ValidationError(Exception):
    """
    Exception raised when ONNX validation fails.

    Contains detailed information about validation failures
    including location and reason for the error.

    Derives directly from Exception, so a generic ``except Exception``
    handler placed before a ``ValidationError`` handler would shadow it.
    """
# Checker context from C++ implementation
C # Module containing C++ checker implementation

import onnx
from onnx.checker import ValidationError, check_model

try:
    # Read the model from disk, then run the default consistency checks on it.
    model = onnx.load_model("path/to/model.onnx")
    check_model(model)
    print("Model is valid!")
except ValidationError as e:
    # The checker rejected the model.
    print(f"Model validation failed: {e}")
except Exception as e:
    # Any other failure raised inside the try block (e.g. while loading).
    print(f"Error loading model: {e}")

import onnx
from onnx.checker import check_model, ValidationError
from onnx.shape_inference import InferenceError

# Load model
model = onnx.load_model("complex_model.onnx")

try:
    # Perform full validation (slower but comprehensive)
    check_model(model, full_check=True)
    print("Model passed comprehensive validation!")
except (ValidationError, InferenceError) as e:
    # full_check=True additionally runs strict shape/type inference, whose
    # failures surface as InferenceError rather than ValidationError. Both
    # are caught so the shape/type hints below are actually reachable.
    print(f"Validation failed: {e}")
    # Print specific validation errors
    print("This could indicate:")
    print("- Unknown or unsupported operators")
    print("- Type mismatches between connected nodes")
    print("- Missing required attributes")
    print("- Invalid tensor shapes or data types")

import onnx
import numpy as np
from onnx import TensorProto, checker, helper

# Node: build a single Relu node and validate it on its own.
node = helper.make_node('Relu', ['input'], ['output'])
try:
    checker.check_node(node)
    print("Node is valid!")
except checker.ValidationError as e:
    print(f"Node validation failed: {e}")

# Tensor: wrap a small float32 array in a TensorProto and validate it.
data = np.array([1.0, 2.0, 3.0], dtype=np.float32)
tensor = helper.make_tensor('test_tensor', TensorProto.FLOAT, [3], data)
try:
    checker.check_tensor(tensor)
    print("Tensor is valid!")
except checker.ValidationError as e:
    print(f"Tensor validation failed: {e}")

import onnx
from onnx import helper, checker, TensorProto
def build_and_validate_model():
    """
    Build a small Conv -> Relu model, validating each piece as it is created.

    The node, the graph and the final model are checked in turn so that a
    failure is reported at the earliest stage that detects it.

    Returns:
    The constructed ModelProto, or None if any validation step fails
    (the failure is printed).
    """
    # Imported locally: this example uses NumPy for the weights but the
    # snippet's import header does not bring `np` into scope.
    import numpy as np

    # Define inputs and outputs.
    # Conv with a 3x3 kernel, stride 1 and padding 1 keeps the 224x224 spatial
    # size, and the weight's first dimension (1000) is the output channel
    # count, so the graph output is [1, 1000, 224, 224].
    X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 3, 224, 224])
    Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 1000, 224, 224])

    # Create nodes
    conv_node = helper.make_node(
        'Conv', ['X', 'W'], ['conv_out'],
        kernel_shape=[3, 3], pads=[1, 1, 1, 1]
    )

    # Validate node immediately
    try:
        checker.check_node(conv_node)
    except checker.ValidationError as e:
        print(f"Invalid node: {e}")
        return None

    relu_node = helper.make_node('Relu', ['conv_out'], ['Y'])

    # Create weight tensor (out_channels, in_channels, kernel_h, kernel_w)
    weight_data = np.random.randn(1000, 3, 3, 3).astype(np.float32)
    weight = helper.make_tensor('W', TensorProto.FLOAT,
                                weight_data.shape, weight_data)

    # Create graph; the weight is supplied as an initializer
    graph = helper.make_graph(
        [conv_node, relu_node],
        'test_model',
        [X], [Y], [weight]
    )

    # Validate graph
    try:
        checker.check_graph(graph)
    except checker.ValidationError as e:
        print(f"Invalid graph: {e}")
        return None

    # Create and validate model
    model = helper.make_model(graph)
    try:
        checker.check_model(model)
        print("Model construction and validation successful!")
        return model
    except checker.ValidationError as e:
        print(f"Invalid model: {e}")
        return None


# Build and validate
model = build_and_validate_model()

import onnx
from onnx import checker
# For advanced use cases, custom validation contexts can be used
# to configure specific validation behavior
try:
model = onnx.load_model("model.onnx")
# Use default context for standard validation
checker.check_model(model, full_check=True)
# Skip operator set compatibility for experimental models
checker.check_model(model, skip_opset_compatibility_check=True)
except checker.ValidationError as e:
# Handle specific validation errors
error_msg = str(e)
if "operator" in error_msg.lower():
print("Operator-related validation error")
elif "type" in error_msg.lower():
print("Type-related validation error")
elif "shape" in error_msg.lower():
print("Shape-related validation error")
else:
print("General validation error")
print(f"Details: {e}")Install with Tessl CLI
npx tessl i tessl/pypi-onnx