or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

backend-integration.md, index.md, model-composition.md, model-construction.md, model-hub.md, model-io.md, model-validation.md, numpy-integration.md, operator-definitions.md, reference-implementation.md, shape-inference.md, text-processing.md, version-conversion.md

docs/shape-inference.md

0

# Shape Inference

1

2

Automatic shape and type inference for ONNX model graphs, enabling optimization and validation of tensor shapes throughout the computation graph. Shape inference is essential for model optimization, validation, and runtime preparation.

3

4

## Capabilities

5

6

### Model Shape Inference

7

8

Infer shapes and types for all values in a model's computation graph.

9

10

```python { .api }

11

def infer_shapes(
    model: ModelProto | bytes,
    check_type: bool = False,
    strict_mode: bool = False,
    data_prop: bool = False,
) -> ModelProto:
    """Apply shape inference to the provided ModelProto.

    Inferred shapes are added to the ``value_info`` field of the graph.

    If the inferred values conflict with values already provided in the
    graph, that means that the provided values are invalid (or there is a
    bug in shape inference), and the result is unspecified.

    Parameters:
        model: The model to run inference on, either as a ``ModelProto``
            or as its serialized bytes.
        check_type: If True, check type equality for inputs and outputs.
        strict_mode: If True, raise an error (``InferenceError``) when
            shape inference fails; otherwise inference simply stops at the
            first error without raising.
        data_prop: If True, enable data propagation for a limited set of
            operators so that shape computations can use propagated values.

    Returns:
        ModelProto: The model with inferred shape information.
    """

def infer_shapes_path(

38

model_path: str | os.PathLike,

39

output_path: str | os.PathLike = "",

40

check_type: bool = False,

41

strict_mode: bool = False,

42

data_prop: bool = False,

43

) -> None:

44

"""

45

Take model path for shape_inference same as infer_shape; it support >2GB models

46

Directly output the inferred model to the output_path; Default is the original model path

47

"""

48

```

49

50

### Node-Level Shape Inference

51

52

Perform shape inference for individual nodes and functions.

53

54

```python { .api }

55

def infer_node_outputs(
    schema: onnx.defs.OpSchema,
    node: onnx.NodeProto,
    input_types: dict[str, onnx.TypeProto],
    input_data: dict[str, onnx.TensorProto] | None = None,
    input_sparse_data: dict[str, onnx.SparseTensorProto] | None = None,
    opset_imports: list[onnx.OperatorSetIdProto] | None = None,
    ir_version: int = onnx.IR_VERSION,
) -> dict[str, onnx.TypeProto]:
    """Infer the output types of one node from its operator schema.

    Parameters:
        schema: The ``OpSchema`` of the node's operator.
        node: The ``NodeProto`` whose outputs should be inferred.
        input_types: Mapping from the node's input names to their
            ``TypeProto``. This is a dict keyed by name, not a list.
        input_data: Optional constant input values (keyed by input name)
            used for data-dependent inference.
        input_sparse_data: Optional sparse constant input values (keyed by
            input name).
        opset_imports: Optional opset imports to use during inference.
        ir_version: The IR version to use.

    Returns:
        dict[str, onnx.TypeProto]: Mapping from output name to the inferred
        output type.

    Raises:
        InferenceError: If inference fails for the node.
    """

def infer_function_output_types(
    function: FunctionProto,
    input_types: Sequence[TypeProto],
    attributes: Sequence[AttributeProto],
) -> list[TypeProto]:
    """Run type-and-shape inference over a function body.

    The body is inferred using the supplied input types and the supplied
    attribute values.

    Parameters:
        function: The ``FunctionProto`` whose body is inferred.
        input_types: Types of the function's inputs, given as a sequence.
        attributes: Attribute values to bind while inferring the body.

    Returns:
        list[TypeProto]: The inferred types of the function's outputs.
    """

```

93

94

### Shape Inference Exceptions

95

96

Exception type raised when shape inference fails (for example, when `strict_mode=True` is passed to `infer_shapes`).

97

98

```python { .api }

99

class InferenceError(Exception):
    """Raised when shape inference cannot complete.

    The error message explains why inference failed and identifies the
    specific node or operation that triggered the failure.
    """

```

107

108

## Usage Examples

109

110

### Basic Shape Inference

111

112

```python

113

import onnx
from onnx import shape_inference

# Load a model without shape information
model = onnx.load_model("model_without_shapes.onnx")

try:
    # strict_mode=True makes inference raise InferenceError on failure.
    # Without it, inference silently stops at the first error and the
    # except branch below would never run.
    inferred_model = shape_inference.infer_shapes(model, strict_mode=True)
except shape_inference.InferenceError as e:
    print(f"Shape inference failed: {e}")
else:
    # Save model with inferred shapes
    onnx.save_model(inferred_model, "model_with_shapes.onnx")
    print("Shape inference completed successfully!")

```

130

131

### Advanced Shape Inference Options

132

133

```python

134

import onnx
from onnx import shape_inference

model = onnx.load_model("complex_model.onnx")

try:
    # Perform shape inference with type checking and data propagation
    inferred_model = shape_inference.infer_shapes(
        model,
        check_type=True,  # Check for type compatibility
        strict_mode=True,  # Raise InferenceError instead of silently stopping
        data_prop=True,  # Enable data value propagation
    )
except shape_inference.InferenceError as e:
    print(f"Shape inference failed: {e}")
else:
    print("Advanced shape inference completed!")

    # Check the inferred shapes
    for value_info in inferred_model.graph.value_info:
        print(f"Value: {value_info.name}")
        if value_info.type.HasField("tensor_type"):
            tensor_type = value_info.type.tensor_type
            # Dims may be concrete (dim_value), symbolic (dim_param) or
            # entirely unknown; show the latter as "?".
            shape = [
                dim.dim_value if dim.HasField("dim_value") else dim.dim_param or "?"
                for dim in tensor_type.shape.dim
            ]
            print(f" Shape: {shape}")
            # elem_type is a TensorProto.DataType enum value; print its name
            print(f" Type: {onnx.TensorProto.DataType.Name(tensor_type.elem_type)}")

```

163

164

### Shape Inference for Large Models

165

166

```python

167

import onnx
from onnx import shape_inference

# For models larger than 2GB, use the path-based API, which reads the model
# from a file and writes the inferred model directly to output_path.
try:
    shape_inference.infer_shapes_path(
        model_path="large_model.onnx",
        output_path="large_model_with_shapes.onnx",
        check_type=True,
        strict_mode=True,  # raise InferenceError instead of silently stopping
    )
except shape_inference.InferenceError as e:
    print(f"Shape inference failed: {e}")
else:
    print("Shape inference for large model completed!")

```

182

183

### Node-Level Shape Inference

184

185

```python

186

import onnx
from onnx import helper, shape_inference, defs, TensorProto

# Type information for the two Conv inputs
input_type = helper.make_tensor_type_proto(TensorProto.FLOAT, [1, 3, 224, 224])
weight_type = helper.make_tensor_type_proto(TensorProto.FLOAT, [64, 3, 3, 3])

# Create a convolution node
conv_node = helper.make_node(
    "Conv",
    inputs=["input", "weight"],
    outputs=["output"],
    kernel_shape=[3, 3],
    pads=[1, 1, 1, 1],
    strides=[1, 1],
)

# Get the operator schema
conv_schema = defs.get_schema("Conv")

try:
    # input_types must be a dict keyed by the node's input names, not a list
    output_types = shape_inference.infer_node_outputs(
        conv_schema,
        conv_node,
        {"input": input_type, "weight": weight_type},
    )
except shape_inference.InferenceError as e:
    print(f"Node-level inference failed: {e}")
else:
    print("Inferred output shape for Conv node:")
    # The result is a dict mapping each output name to its inferred TypeProto
    for name, output_type in output_types.items():
        if output_type.HasField("tensor_type"):
            shape = [
                dim.dim_value if dim.HasField("dim_value") else dim.dim_param or "?"
                for dim in output_type.tensor_type.shape.dim
            ]
            print(f" Output {name}: {shape}")

```

226

227

### Debugging Shape Inference Issues

228

229

```python

230

import onnx
from onnx import shape_inference


def debug_shape_inference(model_path):
    """Debug shape inference issues by examining the model structure.

    Prints the graph inputs and nodes of the model at ``model_path``, then
    runs strict shape inference with type checking. On success it prints
    the inferred ``value_info`` entries; on failure it prints the error and
    a list of common causes.

    Parameters:
        model_path: Path of the ONNX model file to inspect.
    """
    model = onnx.load_model(model_path)

    # Check if model has input shapes defined
    print("Input information:")
    for input_info in model.graph.input:
        print(f" {input_info.name}: {input_info.type}")

    # Check for nodes that might cause issues
    print("\nNodes in graph:")
    for i, node in enumerate(model.graph.node):
        print(f" {i}: {node.op_type} ({node.name or 'unnamed'})")
        print(f" Inputs: {list(node.input)}")
        print(f" Outputs: {list(node.output)}")

    try:
        # strict_mode=True so failures raise InferenceError. Without it,
        # inference silently stops and success would be misreported.
        inferred_model = shape_inference.infer_shapes(
            model, check_type=True, strict_mode=True
        )
    except shape_inference.InferenceError as e:
        print(f"\nShape inference failed: {e}")
        print("Common causes:")
        print("- Missing input shape information")
        print("- Unsupported operators")
        print("- Type mismatches between connected nodes")
        print("- Missing initializer tensors")
    else:
        print("\nShape inference successful!")

        # Show inferred shapes
        print("Inferred value information:")
        for value_info in inferred_model.graph.value_info:
            print(f" {value_info.name}: {value_info.type}")


# Debug a problematic model
# debug_shape_inference("problematic_model.onnx")

```

271

272

### Integrating Shape Inference with Model Construction

273

274

```python

275

import onnx
from onnx import helper, numpy_helper, shape_inference, TensorProto
import numpy as np


def create_model_with_inference():
    """Create a small MLP model and fill in its intermediate shapes.

    Builds ``X -> MatMul(W1) -> Relu -> MatMul(W2) -> Y`` with an unknown
    batch dimension, then runs strict shape inference to populate the
    graph's ``value_info``.

    Returns:
        The model with inferred shape information, or the un-inferred model
        if shape inference failed.
    """
    # Define input/output with an unknown (None) batch dimension
    X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [None, 784])
    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [None, 10])

    # Create weight matrices as initializers. helper.make_tensor expects a
    # flat sequence of values, so passing the 2-D arrays to it directly
    # fails; numpy_helper.from_array keeps the dtype and shape.
    W1 = np.random.randn(784, 128).astype(np.float32)
    W2 = np.random.randn(128, 10).astype(np.float32)
    W1_tensor = numpy_helper.from_array(W1, name="W1")
    W2_tensor = numpy_helper.from_array(W2, name="W2")

    # Create computation nodes
    matmul1 = helper.make_node("MatMul", ["X", "W1"], ["hidden"])
    relu = helper.make_node("Relu", ["hidden"], ["hidden_relu"])
    matmul2 = helper.make_node("MatMul", ["hidden_relu", "W2"], ["Y"])

    # Create graph and model
    graph = helper.make_graph(
        [matmul1, relu, matmul2],
        "mlp_model",
        [X],
        [Y],
        [W1_tensor, W2_tensor],
    )
    model = helper.make_model(graph)

    try:
        # Fill in intermediate shapes; strict_mode=True so a failure raises
        # InferenceError instead of silently stopping.
        inferred_model = shape_inference.infer_shapes(model, strict_mode=True)
    except shape_inference.InferenceError as e:
        print(f"Shape inference failed: {e}")
        return model

    print("Model created with shape inference:")
    for value_info in inferred_model.graph.value_info:
        if value_info.type.HasField("tensor_type"):
            shape = [
                dim.dim_value if dim.HasField("dim_value") else "?"
                for dim in value_info.type.tensor_type.shape.dim
            ]
            print(f" {value_info.name}: {shape}")

    return inferred_model


# Create model with automatic shape inference
model = create_model_with_inference()

```