Open Neural Network Exchange for AI model interoperability and machine learning frameworks
—
The `onnx.defs` module provides programmatic access to the ONNX operator specification: operator schemas, type constraints, and opset-version compatibility information for every supported operator across all registered domains.
Retrieve operator schemas and metadata for validation and code generation.
def get_schema(op_type, max_inclusive_version=None, domain=""):
    """Look up the schema of a single operator.

    When ``max_inclusive_version`` is given, the schema version selected is the
    newest one whose opset version does not exceed that bound.

    Args:
        op_type: Operator name, e.g. ``'Conv'``, ``'Relu'`` or ``'Add'``.
        max_inclusive_version: Highest opset version to take into account.
        domain: Operator domain; ``""`` selects the default ONNX domain.

    Returns:
        OpSchema: The operator's definition together with its constraints.

    Raises:
        SchemaError: The operator is unknown or the requested version is invalid.
    """
def has(op_type, domain=""):
    """Report whether a schema is registered for an operator.

    Args:
        op_type: Operator name to look for.
        domain: Operator domain; ``""`` selects the default ONNX domain.

    Returns:
        bool: ``True`` when a schema for ``op_type`` exists in ``domain``,
        ``False`` otherwise.
    """
def get_all_schemas():
    """Collect the schema of every operator at the current opset version.

    Returns:
        list[OpSchema]: One schema per available operator.
    """
def get_all_schemas_with_history():
    """Collect every operator schema, including superseded historical versions.

    Returns:
        list[OpSchema]: Schemas for all operators across all versions.
    """


# Access to operators that are defined as functions rather than primitives.
def get_function_ops():
    """
    Get the operators that are defined as functions.

    Returns:
    list[OpSchema]: Schemas of the operators that are defined by a function
    body rather than as primitives. These are schema objects, not plain
    names; use ``schema.name`` to obtain an operator's name.
    """


# Get current and supported opset versions.
def onnx_opset_version():
    """Report the opset version of the default ONNX domain.

    Returns:
        int: The current opset version number.
    """


# Standard domain identifiers for ONNX operators.
# Default ONNX domain, identified by the empty string.
ONNX_DOMAIN = ""
# ONNX-ML domain, home of the traditional machine-learning operators.
ONNX_ML_DOMAIN = "ai.onnx.ml"
# Domain for the (preview) training operators.
AI_ONNX_PREVIEW_TRAINING_DOMAIN = "ai.onnx.preview.training"

# Classes for working with operator schemas and handling errors.
class OpSchema:
    """
    Operator schema containing definition, constraints, and metadata.

    Key properties:
    - name: Operator name
    - domain: Operator domain ("" for the default ONNX domain)
    - since_version: Opset version in which this version of the operator
      was introduced (the minimum opset version that uses it)
    - doc: Documentation string
    - attributes: Dict mapping attribute names to attribute schemas
    - inputs: List of input specifications (each with name, description,
      and type_str)
    - outputs: List of output specifications (each with name, description,
      and type_str)
    - type_constraints: List of type constraints for inputs/outputs
    - has_function: True if the operator is defined by a function body
    - function_body: FunctionProto holding that body; empty when
      has_function is False, so test has_function rather than the
      truthiness of this message
    """
class SchemaError(Exception):
    """Raised when an operator schema operation fails.

    Typical causes are requesting the schema of an unknown operator,
    requesting an unsupported opset version, or an invalid operator
    definition.
    """
# Handle to the C++ extension module that implements the operator
# definitions; shown here only as a placeholder.
C = ...

import onnx
from onnx import defs

# defs.has() answers "is this operator registered?" without raising.
if defs.has('Conv'):
    print("Conv operator is available")

# Without a version bound, get_schema() picks the Conv schema for the current opset.
conv_schema = defs.get_schema('Conv')
print(
    f"Conv operator documentation: {conv_schema.doc}",
    f"Conv available since opset version: {conv_schema.since_version}",
    sep="\n",
)

# A version bound selects the newest Relu schema introduced at opset <= 6.
relu_schema = defs.get_schema('Relu', max_inclusive_version=6)
print(f"Relu schema for opset <= 6: {relu_schema.name}")

# Operators outside the default domain need an explicit domain argument.
if defs.has('LinearRegressor', domain=defs.ONNX_ML_DOMAIN):
    ml_schema = defs.get_schema('LinearRegressor', domain=defs.ONNX_ML_DOMAIN)
    print(f"ML operator: {ml_schema.name}")

import onnx
from onnx import defs

# Opset version of the default ONNX domain supported by the installed package.
current_opset = defs.onnx_opset_version()
print(f"Current ONNX opset version: {current_opset}")

# One schema per available operator.
all_schemas = defs.get_all_schemas()
print(f"Total operators available: {len(all_schemas)}")

# Group operator names by domain; the default domain ("") is shown as "ONNX".
operators_by_domain = {}
for schema in all_schemas:
    operators_by_domain.setdefault(schema.domain or "ONNX", []).append(schema.name)

for domain, ops in operators_by_domain.items():
    print(f"{domain} domain: {len(ops)} operators")
    print(f" Examples: {', '.join(sorted(ops)[:5])}")

# get_function_ops() returns OpSchema objects in onnx; print their names rather
# than the objects' reprs (getattr keeps plain string entries working too).
function_ops = defs.get_function_ops()
function_op_names = [getattr(op, "name", op) for op in function_ops]
print(f"Function-based operators: {function_op_names}")

import onnx
from onnx import defs


def _has_default(attr_spec):
    """Return True if an attribute schema carries a declared default value.

    In onnx, ``default_value`` is an ``AttributeProto`` that is present even
    when no default was declared (an empty message), and protobuf messages are
    always truthy, so emptiness is judged by whether any field is populated.
    """
    value = getattr(attr_spec, "default_value", None)
    if value is None:
        return False
    list_fields = getattr(value, "ListFields", None)
    if callable(list_fields):
        # Protobuf message: a default exists only if some field is set.
        return bool(list_fields())
    return bool(value)


def _function_node_count(schema):
    """Return the number of nodes in the schema's function body (0 if none)."""
    # has_function is the authoritative flag; function_body is an (empty but
    # truthy) FunctionProto even for primitive operators.
    if not getattr(schema, "has_function", True):
        return 0
    body = getattr(schema, "function_body", None)
    return len(body.node) if body is not None else 0


def _print_formal_params(title, params):
    """Print a numbered listing of input or output specifications."""
    print(f"\n{title} ({len(params)}):")
    for i, param in enumerate(params):
        print(f" {i}: {param.name} - {param.description}")
        print(f" Type constraints: {param.type_str}")


def analyze_operator(op_name, domain=""):
    """Analyze an operator's schema in detail and print a summary.

    Parameters:
    - op_name: Operator name (e.g. 'Conv')
    - domain: Operator domain ("" for the default ONNX domain)

    A defs.SchemaError raised by the lookup is reported, not propagated.
    """
    try:
        schema = defs.get_schema(op_name, domain=domain)
    except defs.SchemaError as e:
        print(f"Schema error for {op_name}: {e}")
        return

    doc = schema.doc or ""
    # Only mark the documentation as truncated when it actually was.
    doc_preview = doc[:200] + ("..." if len(doc) > 200 else "")
    print(f"Operator: {schema.name}")
    print(f"Domain: {schema.domain or 'ONNX'}")
    print(f"Since version: {schema.since_version}")
    print(f"Documentation: {doc_preview}")

    _print_formal_params("Inputs", schema.inputs)
    _print_formal_params("Outputs", schema.outputs)

    print(f"\nAttributes ({len(schema.attributes)}):")
    for attr_name, attr_spec in schema.attributes.items():
        required = "required" if attr_spec.required else "optional"
        print(f" {attr_name} ({required}): {attr_spec.description}")
        print(f" Type: {attr_spec.type}")
        if _has_default(attr_spec):
            print(f" Default: {attr_spec.default_value}")

    node_count = _function_node_count(schema)
    if node_count:
        print(f"\nFunction-based operator with {node_count} nodes")


# Analyze some common operators
analyze_operator('Conv')
print("\n" + "=" * 50 + "\n")
analyze_operator('BatchNormalization')
print("\n" + "=" * 50 + "\n")
analyze_operator('LinearRegressor', domain=defs.ONNX_ML_DOMAIN)

import onnx
from onnx import defs


def _canonical_onnx_domain(domain):
    """Return "" for the default ONNX domain, whether it is written "" or "ai.onnx"."""
    return "" if domain in ("", "ai.onnx") else domain


def check_opset_compatibility(model_path):
    """Check whether a model's operators resolve against the installed ONNX opsets.

    For every node in the model's main graph, the schema is looked up at the
    opset version the model imports for the node's domain. A node is reported
    as incompatible when:
    - its domain has no entry in the model's opset_import ("domain not imported"),
    - no schema exists at that version ("schema not found"), or
    - it is in the default ONNX domain and the model imports a newer opset
      than the installed onnx package supports ("version too high").

    Parameters:
    - model_path: Path to the .onnx model file

    Note: nodes inside subgraphs (e.g. If/Loop bodies) are not inspected.
    """
    model = onnx.load_model(model_path)
    current_opset = defs.onnx_opset_version()

    # Opset version imported for each domain, with "ai.onnx" folded into "".
    model_opsets = {
        _canonical_onnx_domain(opset_import.domain): opset_import.version
        for opset_import in model.opset_import
    }
    print(f"Model opset versions: {model_opsets}")
    print(f"Current ONNX opset: {current_opset}")

    incompatible_nodes = []
    for node in model.graph.node:
        op_type = node.op_type
        domain = _canonical_onnx_domain(node.domain or "")
        label = node.name or f"node_{op_type}"

        model_opset_version = model_opsets.get(domain)
        if model_opset_version is None:
            # Every domain used by a node must appear in opset_import; borrowing
            # another domain's version would look up the wrong schema.
            incompatible_nodes.append((label, op_type, domain, "domain not imported"))
            continue

        try:
            defs.get_schema(op_type,
                            max_inclusive_version=model_opset_version,
                            domain=domain)
        except defs.SchemaError:
            incompatible_nodes.append((label, op_type, domain, "schema not found"))
            continue

        # onnx_opset_version() describes the default domain only, so the version
        # comparison is meaningful only there.
        if domain == "" and model_opset_version > current_opset:
            incompatible_nodes.append((label, op_type, domain, "version too high"))

    if incompatible_nodes:
        print(f"\nFound {len(incompatible_nodes)} incompatible nodes:")
        for node_name, op_type, domain, reason in incompatible_nodes:
            print(f" {node_name}: {op_type}@{domain} - {reason}")
    else:
        print("\nAll operators are compatible with current opset!")


# Example usage (commented out)
# check_opset_compatibility("model.onnx")

import onnx
from onnx import defs


def find_custom_operators(model_path):
    """Find custom (non-standard) operators in a model's main graph.

    An operator is reported when its domain is not a standard domain (default
    ONNX, ai.onnx.ml, ai.onnx.preview.training), or when its domain is standard
    but defs.has() does not know the operator. Each distinct (op_type, domain)
    pair is reported once, however many nodes use it.

    Parameters:
    - model_path: Path to the .onnx model file

    Note: nodes inside subgraphs (e.g. If/Loop bodies) are not inspected.
    """
    model = onnx.load_model(model_path)
    standard_domains = {"", defs.ONNX_ML_DOMAIN, defs.AI_ONNX_PREVIEW_TRAINING_DOMAIN}

    seen = set()
    custom_operators = []
    for node in model.graph.node:
        op_type = node.op_type
        # "ai.onnx" is an alias of the default ONNX domain "", not a custom domain.
        domain = "" if node.domain in ("", "ai.onnx") else node.domain
        key = (op_type, domain)
        if key in seen:
            continue
        seen.add(key)

        if domain not in standard_domains:
            custom_operators.append((op_type, domain, "custom domain"))
        elif not defs.has(op_type, domain=domain):
            custom_operators.append((op_type, domain, "unknown operator"))

    if not custom_operators:
        print("No custom operators found - model uses only standard ONNX operators")
        return

    print(f"Found {len(custom_operators)} custom operators:")
    for op_type, domain, reason in custom_operators:
        print(f" {op_type}@{domain} - {reason}")

    # Only genuinely non-standard domains belong in this summary; unknown
    # operators in standard domains are already listed above. Sorted so the
    # output is deterministic.
    custom_domains = sorted(
        {domain for _, domain, reason in custom_operators if reason == "custom domain"}
    )
    if custom_domains:
        print(f"\nCustom domains used: {custom_domains}")


# Example usage (commented out)
# find_custom_operators("model_with_custom_ops.onnx")

import onnx
from onnx import defs


def analyze_operator_evolution(op_name, domain=""):
    """Print how an operator's schema changed across opset versions.

    Parameters:
    - op_name: Operator name (e.g. 'Conv')
    - domain: Operator domain ("" for the default ONNX domain)

    For each historical schema, this prints the input, output and attribute
    counts, plus the attributes added or removed relative to the previous
    version. Attribute names are sorted so the output is deterministic.
    """
    display_domain = domain or "ONNX"

    # Every historical version of the requested operator, oldest first.
    op_versions = sorted(
        (schema for schema in defs.get_all_schemas_with_history()
         if schema.name == op_name and schema.domain == domain),
        key=lambda schema: schema.since_version,
    )
    if not op_versions:
        print(f"Operator {op_name}@{display_domain} not found")
        return

    print(f"Evolution of {op_name}@{display_domain}:")
    previous = None
    for schema in op_versions:
        print(f"\nVersion {schema.since_version}:")
        print(f" Inputs: {len(schema.inputs)}")
        print(f" Outputs: {len(schema.outputs)}")
        print(f" Attributes: {len(schema.attributes)}")

        # Diff against the previous version (tracked directly instead of
        # searching the list on every iteration).
        if previous is not None:
            current_attrs = set(schema.attributes)
            previous_attrs = set(previous.attributes)
            new_attrs = current_attrs - previous_attrs
            if new_attrs:
                print(f" New attributes: {sorted(new_attrs)}")
            removed_attrs = previous_attrs - current_attrs
            if removed_attrs:
                print(f" Removed attributes: {sorted(removed_attrs)}")
        previous = schema


# Analyze evolution of some operators
analyze_operator_evolution('Conv')
print("\n" + "=" * 50 + "\n")
analyze_operator_evolution('BatchNormalization')

# Install with Tessl CLI
npx tessl i tessl/pypi-onnx