0
# Model Composition
1
2
Functions for merging and composing multiple ONNX models or graphs, enabling modular model construction and complex pipeline creation. This module supports combining models in various ways to create larger, more complex computational graphs.
3
4
## Capabilities
5
6
### Model Merging
7
8
Combine multiple ONNX models into a single model with proper I/O mapping and name resolution.
9
10
```python { .api }
11
def merge_models(m1, m2, io_map, inputs=None, outputs=None,
                 prefix1=None, prefix2=None, name=None, doc_string=None,
                 producer_name="onnx.compose.merge_models",
                 producer_version="1.0", domain="", model_version=1):
    """
    Merge two ONNX models into a single model.

    Parameters:
    - m1: First ModelProto to merge
    - m2: Second ModelProto to merge
    - io_map: Required list of (m1_output_name, m2_input_name) tuples
      describing which outputs of m1 feed which inputs of m2; pass an
      empty list when the models are not connected to each other
    - inputs: Optional list of input names to keep in the merged model
      (by default every input not consumed through io_map is kept)
    - outputs: Optional list of output names to keep in the merged model
      (by default every output not consumed through io_map is kept)
    - prefix1: Optional prefix to add to all names in m1
    - prefix2: Optional prefix to add to all names in m2
    - name: Optional name for the merged graph
    - doc_string: Optional documentation string for the merged graph
    - producer_name: Producer name recorded in the merged model
    - producer_version: Producer version recorded in the merged model
    - domain: Domain recorded in the merged model
    - model_version: Model version recorded in the merged model

    Returns:
    ModelProto: Merged model combining both input models

    Raises:
    ValueError: If an argument is not a ModelProto, the models use
    different IR versions or conflicting opset versions, io_map refers
    to names that are not outputs of m1 / inputs of m2, or the two
    graphs contain colliding names
    """
32
```
33
34
### Graph Merging
35
36
Combine computation graphs with flexible I/O mapping and name management.
37
38
```python { .api }
39
def merge_graphs(g1, g2, io_map, inputs=None, outputs=None,
                 prefix1=None, prefix2=None, name=None, doc_string=None):
    """
    Merge two GraphProto objects into a single graph.

    Parameters:
    - g1: First GraphProto to merge
    - g2: Second GraphProto to merge
    - io_map: Required list of (g1_output_name, g2_input_name) tuples
      connecting outputs of g1 to inputs of g2
    - inputs: Optional list of input names to keep in the merged graph
      (by default every input not consumed through io_map is kept)
    - outputs: Optional list of output names to keep in the merged graph
      (by default every output not consumed through io_map is kept)
    - prefix1: Optional prefix for names in g1
    - prefix2: Optional prefix for names in g2
    - name: Optional name for the merged graph (derived from the names
      of g1 and g2 when omitted)
    - doc_string: Optional documentation string for the merged graph

    Returns:
    GraphProto: Merged computation graph

    Raises:
    ValueError: If an argument is not a GraphProto, io_map refers to
    names that are not outputs of g1 / inputs of g2, or the two graphs
    contain colliding names
    """
60
61
def check_overlapping_names(g1, g2, io_map=None):
    """
    Check for overlapping names between two graphs.

    Parameters:
    - g1: First GraphProto to check
    - g2: Second GraphProto to check
    - io_map: Optional list of (g1_output_name, g2_input_name) tuples;
      names that appear in it are about to be connected and are therefore
      not reported as overlapping

    Returns:
    list: List of (category, names) tuples, one per category that has
    collisions. category is one of "node", "edge", "value_info",
    "initializer" or "sparse_initializer"; names is the list of names
    present in both graphs. The list is empty when nothing overlaps.

    Raises:
    ValueError: If g1 or g2 is not a GraphProto (name collisions are
    reported through the return value, not raised)
    """
76
```
77
78
### Name Management
79
80
Utilities for managing names and avoiding conflicts in composed models.
81
82
```python { .api }
83
def add_prefix(model, prefix, rename_nodes=True, rename_edges=True,
               rename_inputs=True, rename_outputs=True,
               rename_initializers=True, rename_value_infos=True,
               rename_functions=True, inplace=False):
    """
    Add prefix to all names in a model.

    Parameters:
    - model: ModelProto to modify
    - prefix: Prefix string to add
    - rename_nodes: Whether to rename node names
    - rename_edges: Whether to rename edge (value) names
    - rename_inputs: Whether to rename input names
    - rename_outputs: Whether to rename output names
    - rename_initializers: Whether to rename initializer names
    - rename_value_infos: Whether to rename value info names
    - rename_functions: Whether to rename local function names
    - inplace: If True, modify model directly; otherwise work on a copy

    Returns:
    ModelProto: Model with prefixed names

    Raises:
    ValueError: If model is not a ModelProto
    """
105
106
def add_prefix_graph(graph, prefix, rename_nodes=True, rename_edges=True,
                     rename_inputs=True, rename_outputs=True,
                     rename_initializers=True, rename_value_infos=True,
                     inplace=False, name_map=None):
    """
    Add prefix to all names in a graph.

    Parameters:
    - graph: GraphProto to modify
    - prefix: Prefix string to add
    - rename_nodes: Whether to rename node names
    - rename_edges: Whether to rename edge (value) names
    - rename_inputs: Whether to rename input names
    - rename_outputs: Whether to rename output names
    - rename_initializers: Whether to rename initializer names
    - rename_value_infos: Whether to rename value info names
    - inplace: If True, modify graph directly; otherwise work on a copy
    - name_map: Optional dict of already-decided renames (old name ->
      new name) to reuse

    Returns:
    GraphProto: Graph with prefixed names

    Raises:
    ValueError: If graph is not a GraphProto
    """
128
```
129
130
### Dimension Manipulation
131
132
Utilities for modifying tensor dimensions in composed models.
133
134
```python { .api }
135
def expand_out_dim(model, dim_idx, inplace=False):
    """
    Insert an extra dimension of extent 1 into every output of a model.

    Useful before merging when the downstream model expects an additional
    (for example batch) dimension.

    Parameters:
    - model: ModelProto to modify
    - dim_idx: Index at which the new dimension is inserted; a negative
      value counts from the back
    - inplace: If True, modify model directly; otherwise work on a copy

    Returns:
    ModelProto: Model whose outputs carry the extra dimension

    Raises:
    ValueError: If model is not a ModelProto
    """
150
151
def expand_out_dim_graph(graph, dim_idx, inplace=False):
    """
    Insert an extra dimension of extent 1 into every output of a graph.

    Parameters:
    - graph: GraphProto to modify
    - dim_idx: Index at which the new dimension is inserted; a negative
      value counts from the back
    - inplace: If True, modify graph directly; otherwise work on a copy

    Returns:
    GraphProto: Graph whose outputs carry the extra dimension

    Raises:
    ValueError: If graph is not a GraphProto
    """
166
```
167
168
## Usage Examples
169
170
### Sequential Model Composition
171
172
```python
173
import onnx
174
from onnx import compose, helper, TensorProto
175
import numpy as np
176
177
# Create first model (feature extractor)
178
def create_feature_extractor():
    """Build a placeholder feature-extractor model (Conv followed by Flatten).

    The 224x224 kernel spans the whole 224x224 input, so the Conv result has
    shape [1, 512, 1, 1]. A Flatten node reduces it to the [1, 512] tensor
    that the graph declares as its 'features' output.

    Returns:
        ModelProto with input 'input' [1, 3, 224, 224] and output
        'features' [1, 512].
    """
    X = helper.make_tensor_value_info('input', TensorProto.FLOAT, [1, 3, 224, 224])
    features = helper.make_tensor_value_info('features', TensorProto.FLOAT, [1, 512])

    # Simplified feature extraction (just a placeholder)
    conv_weight = np.random.randn(512, 3, 224, 224).astype(np.float32)
    conv_tensor = helper.make_tensor('conv_w', TensorProto.FLOAT,
                                     conv_weight.shape, conv_weight)

    conv_node = helper.make_node('Conv', ['input', 'conv_w'], ['conv_out'],
                                 kernel_shape=[224, 224])
    # Collapse the trailing 1x1 spatial dims so the output really is [1, 512]
    flatten_node = helper.make_node('Flatten', ['conv_out'], ['features'], axis=1)

    graph = helper.make_graph([conv_node, flatten_node], 'feature_extractor',
                              [X], [features], [conv_tensor])
    return helper.make_model(graph)
193
194
# Create second model (classifier)
195
def create_classifier():
    """Build the classifier model: one MatMul from [1, 512] features to [1, 10].

    Returns:
        ModelProto with input 'features' and output 'output'.
    """
    # Random values stand in for trained fully connected weights
    weights = np.random.randn(512, 10).astype(np.float32)
    weights_init = helper.make_tensor('fc_w', TensorProto.FLOAT,
                                      weights.shape, weights)

    matmul = helper.make_node('MatMul', ['features', 'fc_w'], ['output'])

    graph_inputs = [
        helper.make_tensor_value_info('features', TensorProto.FLOAT, [1, 512]),
    ]
    graph_outputs = [
        helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 10]),
    ]

    return helper.make_model(
        helper.make_graph([matmul], 'classifier',
                          graph_inputs, graph_outputs, [weights_init])
    )
208
209
# Build both stages of the pipeline
feature_model = create_feature_extractor()
classifier_model = create_classifier()

# Chain the stages: each io_map pair is
# (output name in the first model, input name in the second model),
# so 'features' from model 1 feeds the 'features' input of model 2
io_map = [('features', 'features')]

try:
    merged_model = compose.merge_models(
        feature_model, classifier_model,
        io_map=io_map,
        producer_name="composed-model",
    )
    merged_graph = merged_model.graph

    print("Models merged successfully!")
    print(f"Input: {merged_graph.input[0].name}")
    print(f"Output: {merged_graph.output[0].name}")
    print(f"Number of nodes: {len(merged_graph.node)}")

    # Persist the combined model to disk
    onnx.save_model(merged_model, "merged_feature_classifier.onnx")

except ValueError as e:
    print(f"Model merging failed: {e}")
235
```
236
237
### Parallel Model Composition
238
239
```python
240
import onnx
241
from onnx import compose, helper, TensorProto
242
import numpy as np
243
244
def create_branch_model(branch_name, input_dim, output_dim):
    """Build a single-MatMul branch mapping [1, input_dim] to [1, output_dim].

    The weight initializer, the output tensor and the graph are all named
    after branch_name; the input tensor is always called 'input'.
    """
    weight_name = f'weight_{branch_name}'
    output_name = f'output_{branch_name}'

    branch_input = helper.make_tensor_value_info(
        'input', TensorProto.FLOAT, [1, input_dim])
    branch_output = helper.make_tensor_value_info(
        output_name, TensorProto.FLOAT, [1, output_dim])

    # Random values stand in for trained weights
    weights = np.random.randn(input_dim, output_dim).astype(np.float32)
    weights_init = helper.make_tensor(weight_name, TensorProto.FLOAT,
                                      weights.shape, weights)

    matmul = helper.make_node('MatMul', ['input', weight_name], [output_name])

    return helper.make_model(
        helper.make_graph([matmul], f'branch_{branch_name}',
                          [branch_input], [branch_output], [weights_init])
    )
259
260
# Build the two independent branches
branch1 = create_branch_model('A', 128, 64)
branch2 = create_branch_model('B', 128, 32)

# Both branches call their input 'input'; prefixing every name keeps the
# graphs from colliding when they are merged
branch1_prefixed = compose.add_prefix(branch1, 'branch1_')
branch2_prefixed = compose.add_prefix(branch2, 'branch2_')

try:
    # Side-by-side merge: nothing in branch 1 feeds branch 2, so io_map is
    # empty and the result exposes each branch's own (prefixed) input/output
    parallel_model = compose.merge_models(
        branch1_prefixed, branch2_prefixed,
        io_map=[],  # No connections between branches
        producer_name="parallel-branches",
    )
    parallel_graph = parallel_model.graph

    print("Parallel branches merged successfully!")
    print("Inputs:")
    for inp in parallel_graph.input:
        print(f"  {inp.name}")
    print("Outputs:")
    for out in parallel_graph.output:
        print(f"  {out.name}")

    onnx.save_model(parallel_model, "parallel_branches.onnx")

except ValueError as e:
    print(f"Parallel composition failed: {e}")
289
```
290
291
### Complex Pipeline Composition
292
293
```python
294
import onnx
295
from onnx import compose, helper, TensorProto
296
import numpy as np
297
298
def create_preprocessing_model():
    """Build the preprocessing model: normalize 1000 raw values, project to 512.

    Returns:
        ModelProto with input 'raw_data' [1, 1000] and output
        'processed_data' [1, 512].
    """
    # Normalization statistics (identity here) and a random projection matrix
    mean = np.zeros(1000, dtype=np.float32)
    std = np.ones(1000, dtype=np.float32)
    projection = np.random.randn(1000, 512).astype(np.float32)
    initializers = [
        helper.make_tensor(tensor_name, TensorProto.FLOAT, values.shape, values)
        for tensor_name, values in (
            ('mean', mean),
            ('std', std),
            ('projection', projection),
        )
    ]

    # (x - mean) / std, followed by the projection
    nodes = [
        helper.make_node('Sub', ['raw_data', 'mean'], ['centered']),
        helper.make_node('Div', ['centered', 'std'], ['normalized']),
        helper.make_node('MatMul', ['normalized', 'projection'], ['processed_data']),
    ]

    graph_input = helper.make_tensor_value_info(
        'raw_data', TensorProto.FLOAT, [1, 1000])
    graph_output = helper.make_tensor_value_info(
        'processed_data', TensorProto.FLOAT, [1, 512])

    graph = helper.make_graph(nodes, 'preprocessor',
                              [graph_input], [graph_output], initializers)
    return helper.make_model(graph)
322
323
def create_main_model():
    """Build the main processing model: one MatMul from [1, 512] to [1, 10].

    Returns:
        ModelProto with input 'processed_data' and output 'result'.
    """
    # Random values stand in for trained weights
    weights = np.random.randn(512, 10).astype(np.float32)
    weights_init = helper.make_tensor('main_weight', TensorProto.FLOAT,
                                      weights.shape, weights)

    matmul = helper.make_node('MatMul', ['processed_data', 'main_weight'], ['result'])

    graph_input = helper.make_tensor_value_info(
        'processed_data', TensorProto.FLOAT, [1, 512])
    graph_output = helper.make_tensor_value_info(
        'result', TensorProto.FLOAT, [1, 10])

    return helper.make_model(
        helper.make_graph([matmul], 'main_processor',
                          [graph_input], [graph_output], [weights_init])
    )
337
338
def create_postprocessing_model():
    """Build the postprocessing model: Softmax over the 10 result scores.

    Returns:
        ModelProto with input 'result' [1, 10] and output
        'final_output' [1, 10].
    """
    graph_input = helper.make_tensor_value_info(
        'result', TensorProto.FLOAT, [1, 10])
    graph_output = helper.make_tensor_value_info(
        'final_output', TensorProto.FLOAT, [1, 10])

    # Softmax along axis 1 turns the scores into probabilities
    softmax = helper.make_node('Softmax', ['result'], ['final_output'], axis=1)

    return helper.make_model(
        helper.make_graph([softmax], 'postprocessor',
                          [graph_input], [graph_output])
    )
349
350
# Build the three pipeline stages
prep_model = create_preprocessing_model()
main_model = create_main_model()
post_model = create_postprocessing_model()

try:
    # Stage 1 -> stage 2, joined on 'processed_data'
    prep_main = compose.merge_models(
        prep_model,
        main_model,
        io_map=[('processed_data', 'processed_data')],
    )

    # (stage 1 + stage 2) -> stage 3, joined on 'result'
    full_pipeline = compose.merge_models(
        prep_main,
        post_model,
        io_map=[('result', 'result')],
        producer_name="full-pipeline",
    )
    pipeline_graph = full_pipeline.graph

    print("Full pipeline created successfully!")
    print(f"Pipeline: {pipeline_graph.input[0].name} -> {pipeline_graph.output[0].name}")
    print(f"Total nodes: {len(pipeline_graph.node)}")

    # Validate the combined model before saving it
    onnx.checker.check_model(full_pipeline)
    print("Pipeline validation passed!")

    onnx.save_model(full_pipeline, "full_pipeline.onnx")

except Exception as e:
    print(f"Pipeline creation failed: {e}")
381
```
382
383
### Name Conflict Resolution
384
385
```python
386
import onnx
387
from onnx import compose, helper, TensorProto
388
import numpy as np
389
390
def create_model_with_conflicts():
    """Demonstrate detecting and resolving name collisions before a merge.

    Builds two models that use identical tensor names, prints the collisions
    reported by compose.check_overlapping_names, prefixes every name in each
    model, and then chains the first model's output into the second model's
    input. Progress and any failure are printed; nothing is returned.
    """

    # Both models use the same internal names
    def create_simple_model(suffix="", in_dim=10, out_dim=5):
        """Build a MatMul + Relu model mapping [1, in_dim] to [1, out_dim]."""
        X = helper.make_tensor_value_info('input', TensorProto.FLOAT, [1, in_dim])
        Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, out_dim])

        weight = np.random.randn(in_dim, out_dim).astype(np.float32)
        weight_tensor = helper.make_tensor('weight', TensorProto.FLOAT,
                                           weight.shape, weight)

        node = helper.make_node('MatMul', ['input', 'weight'], ['temp'])
        relu_node = helper.make_node('Relu', ['temp'], ['output'])

        graph = helper.make_graph([node, relu_node], f'model{suffix}',
                                  [X], [Y], [weight_tensor])
        return helper.make_model(graph)

    model1 = create_simple_model("1")
    # model2 consumes model1's [1, 5] output, so its input width must be 5
    model2 = create_simple_model("2", in_dim=5, out_dim=5)

    # Check for naming conflicts
    try:
        # check_overlapping_names returns a list of (category, names) tuples
        conflicts = compose.check_overlapping_names(model1.graph, model2.graph)
        if conflicts:
            print("Found naming conflicts:")
            for category, names in conflicts:
                if names:
                    print(f"  {category}: {names}")

        # Resolve conflicts using prefixes
        model1_prefixed = compose.add_prefix(model1, "first_")
        model2_prefixed = compose.add_prefix(model2, "second_")

        # Now merge safely
        merged = compose.merge_models(
            model1_prefixed, model2_prefixed,
            io_map=[('first_output', 'second_input')],  # Connect output to input
            producer_name="conflict-resolved"
        )

        print("Models merged successfully after conflict resolution!")
        print("Final model structure:")
        print(f"  Inputs: {[inp.name for inp in merged.graph.input]}")
        print(f"  Outputs: {[out.name for out in merged.graph.output]}")
        print(f"  Nodes: {len(merged.graph.node)}")

    except Exception as e:
        print(f"Conflict resolution failed: {e}")
440
441
# Run the conflict resolution example
442
create_model_with_conflicts()
443
```