# Shape Inference

Automatic shape and type inference for ONNX model graphs. Inferring the element type and shape of the tensors throughout the computation graph is essential for model optimization, validation, and runtime preparation.

## Capabilities

### Model Shape Inference

Infer shapes and types for all values in a model's computation graph.

```python { .api }
11
def infer_shapes(
    model: ModelProto | bytes,
    check_type: bool = False,
    strict_mode: bool = False,
    data_prop: bool = False,
) -> ModelProto:
    """
    Apply shape inference to the provided ModelProto.

    Inferred shapes are added to the value_info field of the graph.

    If the inferred values conflict with values already provided in the
    graph, that means that the provided values are invalid (or there is a
    bug in shape inference), and the result is unspecified.

    Parameters:
    - model: the model to run inference on, either as a ModelProto or as
      its serialized bytes
    - check_type: check type equality for inputs and outputs
    - strict_mode: if True, raise an error when inference fails; otherwise
      inference simply stops at the first error without raising
    - data_prop: enable data propagation for a limited set of operators so
      they can take part in shape computation

    Returns:
        ModelProto: model with inferred shape information
    """
36
37
def infer_shapes_path(
    model_path: str | os.PathLike,
    output_path: str | os.PathLike = "",
    check_type: bool = False,
    strict_mode: bool = False,
    data_prop: bool = False,
) -> None:
    """
    Apply shape inference to the model stored at model_path.

    Behaves like infer_shapes, but takes a model path instead of a
    ModelProto, which means it also supports models larger than 2GB.

    Parameters:
    - model_path: path of the model file to read
    - output_path: path the inferred model is written to; when empty (the
      default), the original model path is used
    - check_type, strict_mode, data_prop: same meaning as in infer_shapes

    Returns:
        None; the inferred model is written directly to disk.
    """
48
```

### Node-Level Shape Inference

Perform shape inference for individual nodes and functions.

```python { .api }
55
def infer_node_outputs(
    schema: onnx.defs.OpSchema,
    node: onnx.NodeProto,
    input_types: dict[str, onnx.TypeProto],
    input_data: dict[str, onnx.TensorProto] | None = None,
    input_sparse_data: dict[str, onnx.SparseTensorProto] | None = None,
    opset_imports: list[onnx.OperatorSetIdProto] | None = None,
    ir_version: int = onnx.IR_VERSION,
) -> dict[str, onnx.TypeProto]:
    """
    Run type and shape inference for one node and return its output types.

    Parameters:
    - schema: OpSchema describing the node's operator
    - node: the NodeProto whose outputs should be inferred
    - input_types: TypeProto of each node input, keyed by input name
    - input_data: optional constant input tensors, used by data-dependent
      inference
    - input_sparse_data: optional constant sparse input tensors
    - opset_imports: optional opset imports to infer against
    - ir_version: IR version to infer against (defaults to onnx.IR_VERSION)

    Returns:
        dict[str, onnx.TypeProto]: the inferred output types

    Raises:
        InferenceError: when inference fails for the node
    """
82
83
def infer_function_output_types(
    function: FunctionProto,
    input_types: Sequence[TypeProto],
    attributes: Sequence[AttributeProto],
) -> list[TypeProto]:
    """
    Run type and shape inference over the body of a function.

    The body of ``function`` is inferred using the given input types and
    the given attribute values.

    Parameters:
    - function: FunctionProto whose body is inferred
    - input_types: types of the function's inputs (presumably one per
      input, in declaration order)
    - attributes: attribute values bound while inferring the body

    Returns:
        list[TypeProto]: the inferred output types
    """
92
```

### Shape Inference Exceptions

Exception types raised when shape inference fails.

```python { .api }
99
class InferenceError(Exception):
    """
    Raised when shape or type inference cannot be completed.

    The message explains why inference failed and identifies the specific
    node or operation responsible for the failure.
    """
106
```

## Usage Examples

### Basic Shape Inference

```python
113
import onnx
from onnx import shape_inference

# Start from a model that carries no intermediate shape information.
source_model = onnx.load_model("model_without_shapes.onnx")

# NOTE: with the default strict_mode=False, inference stops at the first
# error instead of raising; pass strict_mode=True to get an InferenceError.
try:
    inferred_model = shape_inference.infer_shapes(source_model)
except shape_inference.InferenceError as e:
    print(f"Shape inference failed: {e}")
else:
    # Persist the model together with its inferred shapes.
    onnx.save_model(inferred_model, "model_with_shapes.onnx")
    print("Shape inference completed successfully!")
129
```

### Advanced Shape Inference Options

```python
134
import onnx
from onnx import shape_inference

model = onnx.load_model("complex_model.onnx")

# Enable every optional behaviour of infer_shapes.
inference_options = {
    "check_type": True,   # check type equality for inputs and outputs
    "strict_mode": True,  # raise on failure instead of stopping silently
    "data_prop": True,    # propagate data through a limited set of operators
}

try:
    inferred_model = shape_inference.infer_shapes(model, **inference_options)
    print("Advanced shape inference completed!")

    # Inferred intermediate values are recorded in graph.value_info.
    for info in inferred_model.graph.value_info:
        print(f"Value: {info.name}")
        if not info.type.HasField('tensor_type'):
            continue
        tensor_type = info.type.tensor_type
        dims = [
            d.dim_value if d.HasField('dim_value') else d.dim_param
            for d in tensor_type.shape.dim
        ]
        print(f" Shape: {dims}")
        print(f" Type: {tensor_type.elem_type}")

except shape_inference.InferenceError as e:
    print(f"Shape inference failed: {e}")
162
```

### Shape Inference for Large Models

```python
167
import onnx
from onnx import shape_inference

# infer_shapes_path reads and writes the model on disk, so unlike
# infer_shapes it also supports models larger than 2GB.
source_path = "large_model.onnx"
target_path = "large_model_with_shapes.onnx"

try:
    shape_inference.infer_shapes_path(
        model_path=source_path,
        output_path=target_path,
        check_type=True,
    )
except shape_inference.InferenceError as e:
    print(f"Shape inference failed: {e}")
else:
    print("Shape inference for large model completed!")
181
```

### Node-Level Shape Inference

```python
186
import onnx
from onnx import helper, shape_inference, defs, TensorProto

# Type information for the node's 'input' and 'weight' inputs
input_type = helper.make_tensor_type_proto(TensorProto.FLOAT, [1, 3, 224, 224])
weight_type = helper.make_tensor_type_proto(TensorProto.FLOAT, [64, 3, 3, 3])

# Create a convolution node
conv_node = helper.make_node(
    'Conv',
    inputs=['input', 'weight'],
    outputs=['output'],
    kernel_shape=[3, 3],
    pads=[1, 1, 1, 1],
    strides=[1, 1],
)

# Get the operator schema
conv_schema = defs.get_schema('Conv')

# infer_node_outputs takes the input types as a dict keyed by the node's
# input names (not a positional list) and returns a dict of output types.
input_types = {'input': input_type, 'weight': weight_type}

try:
    output_types = shape_inference.infer_node_outputs(
        conv_schema,
        conv_node,
        input_types,
    )

    # With 1-pixel padding, a 3x3 kernel and stride 1 the spatial size is
    # preserved, so this should report [1, 64, 224, 224].
    print("Inferred output shape for Conv node:")
    for output_name, output_type in output_types.items():
        if output_type.HasField('tensor_type'):
            # Test HasField rather than truthiness so a real 0-sized
            # dimension is not replaced by an (empty) dim_param.
            shape = [
                dim.dim_value if dim.HasField('dim_value') else dim.dim_param
                for dim in output_type.tensor_type.shape.dim
            ]
            print(f" Output {output_name}: {shape}")

except shape_inference.InferenceError as e:
    print(f"Node-level inference failed: {e}")
225
```

### Debugging Shape Inference Issues

```python
230
import onnx
231
from onnx import shape_inference
232
233
def debug_shape_inference(model_path):
    """Debug shape inference issues by examining the model structure.

    Prints the graph inputs and nodes of the model stored at ``model_path``.
    It then runs shape inference with type checking in strict mode. On
    success it prints the inferred value_info entries; on failure it prints
    the error followed by a list of common causes.

    Parameters:
    - model_path: path of the ONNX model file to load

    Returns:
        None; all results are printed.
    """
    model = onnx.load_model(model_path)

    # Check if model has input shapes defined
    print("Input information:")
    for input_info in model.graph.input:
        print(f" {input_info.name}: {input_info.type}")

    # Check for nodes that might cause issues
    print("\nNodes in graph:")
    for i, node in enumerate(model.graph.node):
        print(f" {i}: {node.op_type} ({node.name or 'unnamed'})")
        print(f" Inputs: {list(node.input)}")
        print(f" Outputs: {list(node.output)}")

    try:
        # strict_mode=True makes failures surface as InferenceError. In the
        # default non-strict mode inference just stops at the first error,
        # so the diagnostics below would typically never be printed.
        inferred_model = shape_inference.infer_shapes(
            model, check_type=True, strict_mode=True
        )
    except shape_inference.InferenceError as e:
        print(f"\nShape inference failed: {e}")
        print("Common causes:")
        print("- Missing input shape information")
        print("- Unsupported operators")
        print("- Type mismatches between connected nodes")
        print("- Missing initializer tensors")
    else:
        print("\nShape inference successful!")

        # Show inferred shapes
        print("Inferred value information:")
        for value_info in inferred_model.graph.value_info:
            print(f" {value_info.name}: {value_info.type}")
267
268
# Debug a problematic model
269
# debug_shape_inference("problematic_model.onnx")
270
```

### Integrating Shape Inference with Model Construction

```python
275
import onnx
276
from onnx import helper, shape_inference, TensorProto
277
import numpy as np
278
279
def create_model_with_inference():
    """Build a small two-layer MLP model and run shape inference on it.

    The graph is MatMul -> Relu -> MatMul, using randomly initialised
    weights W1 (784x128) and W2 (128x10) that are stored as initializers.
    The leading dimension of X and Y is left unspecified.

    Returns:
        The model returned by ``shape_inference.infer_shapes`` (its
        intermediate shapes are printed). If inference raises
        ``InferenceError``, the un-inferred model is returned instead.
    """
    # Graph inputs/outputs with an unspecified leading dimension.
    graph_inputs = [helper.make_tensor_value_info('X', TensorProto.FLOAT, [None, 784])]
    graph_outputs = [helper.make_tensor_value_info('Y', TensorProto.FLOAT, [None, 10])]

    # Random weights, drawn in the order W1 then W2.
    initializers = []
    for weight_name, dims in (('W1', (784, 128)), ('W2', (128, 10))):
        values = np.random.randn(*dims).astype(np.float32)
        initializers.append(
            helper.make_tensor(weight_name, TensorProto.FLOAT, values.shape, values)
        )

    nodes = [
        helper.make_node('MatMul', ['X', 'W1'], ['hidden']),
        helper.make_node('Relu', ['hidden'], ['hidden_relu']),
        helper.make_node('MatMul', ['hidden_relu', 'W2'], ['Y']),
    ]

    graph = helper.make_graph(
        nodes, 'mlp_model', graph_inputs, graph_outputs, initializers
    )
    model = helper.make_model(graph)

    try:
        # Fill in the shapes of the intermediate values.
        inferred = shape_inference.infer_shapes(model)
    except shape_inference.InferenceError as e:
        print(f"Shape inference failed: {e}")
        return model

    print("Model created with shape inference:")
    for info in inferred.graph.value_info:
        if not info.type.HasField('tensor_type'):
            continue
        dims_repr = [
            d.dim_value if d.HasField('dim_value') else '?'
            for d in info.type.tensor_type.shape.dim
        ]
        print(f" {info.name}: {dims_repr}")
    return inferred
325
326
# Build the example MLP; its intermediate shapes are filled in by inference
model = create_model_with_inference()
328
```