Open Neural Network Exchange (ONNX): model composition utilities for AI model interoperability across machine learning frameworks
—
Functions for merging and composing multiple ONNX models or graphs, enabling modular model construction and complex pipeline creation. This module supports combining models in various ways to create larger, more complex computational graphs.
Combine multiple ONNX models into a single model with proper I/O mapping and name resolution.
def merge_models(m1, m2, io_map=None, prefix1="", prefix2="",
                 doc_string="", producer_name="", producer_version=""):
    """
    Merge two ONNX models into a single model.

    Parameters:
    - m1: First ModelProto to merge
    - m2: Second ModelProto to merge
    - io_map: List of tuples mapping outputs of m1 to inputs of m2
    - prefix1: Prefix to add to all names in m1
    - prefix2: Prefix to add to all names in m2
    - doc_string: Documentation for the merged model
    - producer_name: Producer name for the merged model
    - producer_version: Producer version for the merged model

    Returns:
        ModelProto: Merged model combining both input models

    Raises:
        ValueError: If models cannot be merged due to incompatible types or shapes
    """


# Combine computation graphs with flexible I/O mapping and name management.
def merge_graphs(g1, g2, io_map=None, prefix1="", prefix2="",
                 inputs=None, outputs=None, name="merged_graph"):
    """Combine two GraphProto objects into one computation graph.

    :param g1: First GraphProto to merge.
    :param g2: Second GraphProto to merge.
    :param io_map: Tuples connecting outputs of ``g1`` to inputs of ``g2``.
    :param prefix1: Prefix applied to names in ``g1``.
    :param prefix2: Prefix applied to names in ``g2``.
    :param inputs: Input specifications for the merged graph
        (auto-detected when None).
    :param outputs: Output specifications for the merged graph
        (auto-detected when None).
    :param name: Name given to the merged graph.
    :returns: GraphProto holding the merged computation graph.
    :raises ValueError: If the graphs cannot be merged because of naming
        conflicts or type mismatches.
    """
def check_overlapping_names(g1, g2, io_map=None):
    """
    Check for overlapping names between two graphs.

    Parameters:
    - g1: First GraphProto to check
    - g2: Second GraphProto to check
    - io_map: I/O mapping that affects naming

    Returns:
        dict: Dictionary containing lists of overlapping names by category

    Raises:
        ValueError: If there are unresolvable name conflicts
    """


# Utilities for managing names and avoiding conflicts in composed models.
def add_prefix(model, prefix, rename_nodes=True, rename_edges=True,
               rename_inputs=True, rename_outputs=True,
               rename_initializers=True, rename_value_infos=True):
    """Prepend ``prefix`` to names throughout a model.

    Each ``rename_*`` flag controls one category of names.

    :param model: ModelProto to modify.
    :param prefix: Prefix string to prepend.
    :param rename_nodes: Rename node names when True.
    :param rename_edges: Rename edge (value) names when True.
    :param rename_inputs: Rename graph input names when True.
    :param rename_outputs: Rename graph output names when True.
    :param rename_initializers: Rename initializer names when True.
    :param rename_value_infos: Rename value info names when True.
    :returns: ModelProto with prefixed names.
    :raises ValueError: If the prefix produces invalid names.
    """
def add_prefix_graph(graph, prefix, rename_nodes=True, rename_edges=True,
                     rename_inputs=True, rename_outputs=True,
                     rename_initializers=True, rename_value_infos=True):
    """
    Add prefix to all names in a graph.

    Parameters:
    - graph: GraphProto to modify
    - prefix: Prefix string to add
    - rename_nodes: Whether to rename node names
    - rename_edges: Whether to rename edge (value) names
    - rename_inputs: Whether to rename input names
    - rename_outputs: Whether to rename output names
    - rename_initializers: Whether to rename initializer names
    - rename_value_infos: Whether to rename value info names

    Returns:
        GraphProto: Graph with prefixed names

    Raises:
        ValueError: If prefix causes invalid names
    """


# Utilities for modifying tensor dimensions in composed models.
def expand_out_dim(model, dim_idx, incr=1):
    """Expand one output dimension of a model.

    :param model: ModelProto to modify.
    :param dim_idx: Index of the output dimension to expand.
    :param incr: Amount added to that dimension (default 1).
    :returns: ModelProto with the expanded output dimension.
    :raises ValueError: If the requested expansion is invalid.
    """
def expand_out_dim_graph(graph, dim_idx, incr=1):
    """
    Expand output dimensions in a graph.

    Parameters:
    - graph: GraphProto to modify
    - dim_idx: Index of dimension to expand
    - incr: Amount to increment the dimension

    Returns:
        GraphProto: Graph with expanded output dimensions

    Raises:
        ValueError: If dimension expansion is invalid
    """


import onnx
from onnx import compose, helper, TensorProto
import numpy as np
# Create first model (feature extractor)
def create_feature_extractor():
    """Build a toy feature-extractor: one full-image Conv producing 'features'."""
    in_info = helper.make_tensor_value_info('input', TensorProto.FLOAT, [1, 3, 224, 224])
    out_info = helper.make_tensor_value_info('features', TensorProto.FLOAT, [1, 512])
    # Simplified feature extraction (just a placeholder).
    w = np.random.randn(512, 3, 224, 224).astype(np.float32)
    w_init = helper.make_tensor('conv_w', TensorProto.FLOAT,
                                w.shape, w)
    conv = helper.make_node('Conv', ['input', 'conv_w'], ['features'],
                            kernel_shape=[224, 224])
    g = helper.make_graph([conv], 'feature_extractor',
                          [in_info], [out_info], [w_init])
    return helper.make_model(g)
# Create second model (classifier)
def create_classifier():
    """Build a toy classifier: one 512 -> 10 MatMul over 'features'."""
    in_info = helper.make_tensor_value_info('features', TensorProto.FLOAT, [1, 512])
    out_info = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 10])
    w = np.random.randn(512, 10).astype(np.float32)
    w_init = helper.make_tensor('fc_w', TensorProto.FLOAT,
                                w.shape, w)
    fc = helper.make_node('MatMul', ['features', 'fc_w'], ['output'])
    g = helper.make_graph([fc], 'classifier',
                          [in_info], [out_info], [w_init])
    return helper.make_model(g)
# Create and merge models
feature_model = create_feature_extractor()
classifier_model = create_classifier()

# Merge the models sequentially:
# the 'features' output of model 1 connects to the 'features' input of model 2.
io_map = [('features', 'features')]

try:
    merged_model = compose.merge_models(
        feature_model,
        classifier_model,
        io_map=io_map,
        producer_name="composed-model"
    )
    print("Models merged successfully!")
    print(f"Input: {merged_model.graph.input[0].name}")
    print(f"Output: {merged_model.graph.output[0].name}")
    print(f"Number of nodes: {len(merged_model.graph.node)}")
    # Save the merged model
    onnx.save_model(merged_model, "merged_feature_classifier.onnx")
except ValueError as e:
    print(f"Model merging failed: {e}")

import onnx
from onnx import compose, helper, TensorProto
import numpy as np
def create_branch_model(branch_name, input_dim, output_dim):
    """Build a one-layer MatMul branch whose names are keyed by branch_name."""
    in_info = helper.make_tensor_value_info('input', TensorProto.FLOAT, [1, input_dim])
    out_info = helper.make_tensor_value_info(f'output_{branch_name}', TensorProto.FLOAT, [1, output_dim])
    w = np.random.randn(input_dim, output_dim).astype(np.float32)
    w_init = helper.make_tensor(f'weight_{branch_name}', TensorProto.FLOAT,
                                w.shape, w)
    mm = helper.make_node('MatMul', ['input', f'weight_{branch_name}'],
                          [f'output_{branch_name}'])
    g = helper.make_graph([mm], f'branch_{branch_name}',
                          [in_info], [out_info], [w_init])
    return helper.make_model(g)
# Create two parallel branches
branch1 = create_branch_model('A', 128, 64)
branch2 = create_branch_model('B', 128, 32)

# Add prefixes to avoid name conflicts between the branches.
branch1_prefixed = compose.add_prefix(branch1, 'branch1_')
branch2_prefixed = compose.add_prefix(branch2, 'branch2_')

try:
    # Merge with shared input (no I/O mapping needed for parallel composition)
    parallel_model = compose.merge_models(
        branch1_prefixed,
        branch2_prefixed,
        io_map=[],  # No connections between branches
        producer_name="parallel-branches"
    )
    print("Parallel branches merged successfully!")
    print("Inputs:")
    for inp in parallel_model.graph.input:
        print(f" {inp.name}")
    print("Outputs:")
    for out in parallel_model.graph.output:
        print(f" {out.name}")
    onnx.save_model(parallel_model, "parallel_branches.onnx")
except ValueError as e:
    print(f"Parallel composition failed: {e}")

import onnx
from onnx import compose, helper, TensorProto
import numpy as np
def create_preprocessing_model():
    """Build a preprocessing model: center, scale, then project 1000 -> 512."""
    in_info = helper.make_tensor_value_info('raw_data', TensorProto.FLOAT, [1, 1000])
    out_info = helper.make_tensor_value_info('processed_data', TensorProto.FLOAT, [1, 512])
    # Normalization parameters (zero mean / unit std as a placeholder).
    mean_vals = np.zeros(1000, dtype=np.float32)
    std_vals = np.ones(1000, dtype=np.float32)
    proj_vals = np.random.randn(1000, 512).astype(np.float32)
    mean_init = helper.make_tensor('mean', TensorProto.FLOAT, mean_vals.shape, mean_vals)
    std_init = helper.make_tensor('std', TensorProto.FLOAT, std_vals.shape, std_vals)
    proj_init = helper.make_tensor('projection', TensorProto.FLOAT,
                                   proj_vals.shape, proj_vals)
    # Normalize ((x - mean) / std), then project down to 512 features.
    center = helper.make_node('Sub', ['raw_data', 'mean'], ['centered'])
    scale = helper.make_node('Div', ['centered', 'std'], ['normalized'])
    project = helper.make_node('MatMul', ['normalized', 'projection'], ['processed_data'])
    g = helper.make_graph([center, scale, project], 'preprocessor',
                          [in_info], [out_info],
                          [mean_init, std_init, proj_init])
    return helper.make_model(g)
def create_main_model():
    """Build the main processing model: a single 512 -> 10 MatMul."""
    in_info = helper.make_tensor_value_info('processed_data', TensorProto.FLOAT, [1, 512])
    out_info = helper.make_tensor_value_info('result', TensorProto.FLOAT, [1, 10])
    w = np.random.randn(512, 10).astype(np.float32)
    w_init = helper.make_tensor('main_weight', TensorProto.FLOAT,
                                w.shape, w)
    mm = helper.make_node('MatMul', ['processed_data', 'main_weight'], ['result'])
    g = helper.make_graph([mm], 'main_processor',
                          [in_info], [out_info], [w_init])
    return helper.make_model(g)
def create_postprocessing_model():
    """Build a postprocessing model that applies Softmax to 'result'."""
    in_info = helper.make_tensor_value_info('result', TensorProto.FLOAT, [1, 10])
    out_info = helper.make_tensor_value_info('final_output', TensorProto.FLOAT, [1, 10])
    # Softmax over the class axis yields the final probabilities.
    smax = helper.make_node('Softmax', ['result'], ['final_output'], axis=1)
    g = helper.make_graph([smax], 'postprocessor',
                          [in_info], [out_info])
    return helper.make_model(g)
# Create individual models
prep_model = create_preprocessing_model()
main_model = create_main_model()
post_model = create_postprocessing_model()

try:
    # First merge preprocessing and main processing
    prep_main = compose.merge_models(
        prep_model, main_model,
        io_map=[('processed_data', 'processed_data')]
    )
    # Then merge with postprocessing
    full_pipeline = compose.merge_models(
        prep_main, post_model,
        io_map=[('result', 'result')],
        producer_name="full-pipeline"
    )
    print("Full pipeline created successfully!")
    print(f"Pipeline: {full_pipeline.graph.input[0].name} -> {full_pipeline.graph.output[0].name}")
    print(f"Total nodes: {len(full_pipeline.graph.node)}")
    # Verify the pipeline structure
    onnx.checker.check_model(full_pipeline)
    print("Pipeline validation passed!")
    onnx.save_model(full_pipeline, "full_pipeline.onnx")
except Exception as e:
    print(f"Pipeline creation failed: {e}")

import onnx
from onnx import compose, helper, TensorProto
import numpy as np
def create_model_with_conflicts():
    """Create two models that would have naming conflicts, then merge them.

    Demonstrates detecting overlapping names with check_overlapping_names
    and resolving them with add_prefix before calling merge_models.
    """
    # Both models use the same internal names ('input', 'weight', 'temp', 'output').
    # Dimensions are parameterized so the two models can actually be chained:
    # the second model's input width must equal the first model's output width,
    # otherwise merge_models raises a shape-mismatch ValueError.
    def create_simple_model(suffix="", in_dim=10, out_dim=5):
        X = helper.make_tensor_value_info('input', TensorProto.FLOAT, [1, in_dim])
        Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, out_dim])
        weight = np.random.randn(in_dim, out_dim).astype(np.float32)
        weight_tensor = helper.make_tensor('weight', TensorProto.FLOAT,
                                           weight.shape, weight)
        node = helper.make_node('MatMul', ['input', 'weight'], ['temp'])
        relu_node = helper.make_node('Relu', ['temp'], ['output'])
        graph = helper.make_graph([node, relu_node], f'model{suffix}',
                                  [X], [Y], [weight_tensor])
        return helper.make_model(graph)

    model1 = create_simple_model("1")                      # [1, 10] -> [1, 5]
    # Bug fix: model2's input width matches model1's output width (5) so the
    # io_map connection below is type/shape compatible.
    model2 = create_simple_model("2", in_dim=5, out_dim=5)
    # Check for naming conflicts
    try:
        conflicts = compose.check_overlapping_names(model1.graph, model2.graph)
        if conflicts:
            print("Found naming conflicts:")
            for category, names in conflicts.items():
                if names:
                    print(f" {category}: {names}")
        # Resolve conflicts using prefixes
        model1_prefixed = compose.add_prefix(model1, "first_")
        model2_prefixed = compose.add_prefix(model2, "second_")
        # Now merge safely
        merged = compose.merge_models(
            model1_prefixed, model2_prefixed,
            io_map=[('first_output', 'second_input')],  # Connect output to input
            producer_name="conflict-resolved"
        )
        print("Models merged successfully after conflict resolution!")
        print("Final model structure:")
        print(f" Inputs: {[inp.name for inp in merged.graph.input]}")
        print(f" Outputs: {[out.name for out in merged.graph.output]}")
        print(f" Nodes: {len(merged.graph.node)}")
    except Exception as e:
        print(f"Conflict resolution failed: {e}")
# Run the conflict resolution example
create_model_with_conflicts()

# Install with Tessl CLI:
#   npx tessl i tessl/pypi-onnx