0
# Advanced Features
1
2
JIT compilation, model export, graph transformations, quantization, and deployment utilities for optimizing and deploying PyTorch models in production environments.
3
4
## Capabilities
5
6
### JIT Compilation (torch.jit)
7
8
TorchScript compilation for model optimization and deployment.
9
10
```python { .api }
11
def jit.script(obj, optimize=None, _frames_up=0, _rcb=None):
12
"""
13
Compile Python code to TorchScript.
14
15
Parameters:
16
- obj: Function, method, or class to compile
17
- optimize: Deprecated and has no effect; use `torch.jit.optimized_execution()` instead
18
19
Returns:
20
ScriptModule or ScriptFunction
21
"""
22
23
def jit.trace(func, example_inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-5, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=None):
24
"""
25
Trace function execution to create TorchScript.
26
27
Parameters:
28
- func: Function or module to trace
29
- example_inputs: Example inputs for tracing
30
- optimize: Deprecated and has no effect; use `torch.jit.optimized_execution()` instead
31
- check_trace: Whether to verify trace correctness
32
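- strict: Whether to run the tracer in strict mode; disable only to let the tracer record mutable container types (list/dict) whose structure is constant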
33
34
Returns:
35
TracedModule or function
36
"""
37
38
def jit.load(f, map_location=None, _extra_files=None):
39
"""Load TorchScript model from file."""
40
41
def jit.save(m, f, _extra_files=None):
42
"""Save TorchScript model to file."""
43
44
class jit.ScriptModule(nn.Module):
45
"""TorchScript compiled module."""
46
def save(self, f, _extra_files=None): ...
47
def code(self) -> str: ...
48
def graph(self): ...
49
def code_with_constants(self) -> Tuple[str, List[Tensor]]: ...
50
51
def jit.freeze(mod, preserved_attrs=None, optimize_numerics=True):
52
"""Freeze TorchScript module for inference."""
53
54
def jit.optimize_for_inference(mod, other_methods=None):
55
"""Optimize TorchScript module for inference."""
56
57
def jit.enable_onednn_fusion(enabled: bool):
58
"""Enable/disable OneDNN fusion optimization."""
59
60
def jit.set_fusion_strategy(strategy: List[Tuple[str, bool]]):
61
"""Set fusion strategy for optimization."""
62
```
63
64
### Model Export (torch.export)
65
66
Export PyTorch models for deployment and optimization.
67
68
```python { .api }
69
def export.export(mod: nn.Module, args, kwargs=None, *, dynamic_shapes=None, strict=True) -> ExportedProgram:
70
"""
71
Export PyTorch module to exportable format.
72
73
Parameters:
74
- mod: Module to export
75
- args: Example arguments
76
- kwargs: Example keyword arguments
77
- dynamic_shapes: Dynamic shape specifications
78
- strict: Whether to trace through TorchDynamo (strict mode) instead of non-strict tracing with the Python interpreter
79
80
Returns:
81
ExportedProgram
82
"""
83
84
class export.ExportedProgram:
85
"""Exported PyTorch program."""
86
def module(self) -> nn.Module: ...
87
def graph_module(self): ...
88
def graph_signature(self): ...
89
def call_spec(self): ...
90
def verifier(self): ...
91
def state_dict(self) -> Dict[str, Any]: ...
92
def named_parameters(self): ...
93
def named_buffers(self): ...
94
95
def export.save(ep: ExportedProgram, f) -> None:
96
"""Save exported program to file."""
97
98
def export.load(f) -> ExportedProgram:
99
"""Load exported program from file."""
100
```
101
102
### Model Compilation (torch.compile)
103
104
Compile PyTorch models for performance optimization.
105
106
```python { .api }
107
def compile(model=None, *, fullgraph=False, dynamic=None, backend="inductor", mode=None, options=None, disable=False):
108
"""
109
Compile PyTorch model for optimization.
110
111
Parameters:
112
- model: Model to compile (or use as decorator)
113
- fullgraph: Whether to compile the entire graph
114
- dynamic: Enable dynamic shapes
115
- backend: Compilation backend ("inductor", "aot_eager", etc.)
116
- mode: Compilation mode ("default", "reduce-overhead", "max-autotune")
117
- options: Backend-specific options
118
- disable: Disable compilation
119
120
Returns:
121
Compiled model
122
"""
123
124
@compile
125
def compiled_function(x):
126
"""Example of function compilation."""
127
return x * 2 + 1
128
129
# Alternative usage
130
compiled_model = torch.compile(model, mode="max-autotune")
131
```
132
133
### Graph Transformations (torch.fx)
134
135
Symbolic tracing and graph manipulation for model analysis and optimization.
136
137
```python { .api }
138
class fx.GraphModule(nn.Module):
139
"""Module with FX graph representation."""
140
def __init__(self, root, graph, class_name='GraphModule'): ...
141
def recompile(self): ...
142
def code(self) -> str: ...
143
def graph(self): ...
144
def print_readable(self, print_output=True): ...
145
146
def fx.symbolic_trace(root, concrete_args=None, meta_args=None, _force_outplace=False) -> GraphModule:
147
"""
148
Symbolically trace PyTorch module.
149
150
Parameters:
151
- root: Module or function to trace
152
- concrete_args: Arguments to keep concrete
153
- meta_args: Meta tensor arguments
154
155
Returns:
156
GraphModule with traced computation graph
157
"""
158
159
class fx.Tracer:
160
"""Tracer for symbolic execution."""
161
def trace(self, root, concrete_args=None): ...
162
def call_module(self, m, forward, args, kwargs): ...
163
def call_function(self, target, args, kwargs): ...
164
def call_method(self, target, args, kwargs): ...
165
166
class fx.Graph:
167
"""Computational graph representation."""
168
def nodes(self): ...
169
def create_node(self, op, target, args=None, kwargs=None, name=None, type_expr=None): ...
170
def erase_node(self, to_erase): ...
171
def inserting_before(self, n): ...
172
def inserting_after(self, n): ...
173
def lint(self): ...
174
def print_tabular(self): ...
175
176
class fx.Node:
177
"""Node in FX graph."""
178
def replace_all_uses_with(self, replace_with): ...
179
def replace_input_with(self, old_input, new_input): ...
180
def append(self, x): ...
181
def prepend(self, x): ...
182
183
def fx.replace_pattern(gm: GraphModule, pattern, replacement) -> List[Match]:
184
"""Replace patterns in graph."""
185
186
class fx.Interpreter:
187
"""Base class for FX graph interpreters."""
188
def run(self, *args, **kwargs): ...
189
def run_node(self, n): ...
190
def call_function(self, target, args, kwargs): ...
191
def call_method(self, target, args, kwargs): ...
192
def call_module(self, target, args, kwargs): ...
193
```
194
195
### Quantization (torch.quantization)
196
197
Model quantization for efficient deployment.
198
199
```python { .api }
200
def quantization.quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8, mapping=None, inplace=False, remove_qconfig=True):
201
"""
202
Dynamic quantization of model.
203
204
Parameters:
205
- model: Model to quantize
206
- qconfig_spec: Quantization configuration
207
- dtype: Target quantized data type
208
- mapping: Custom op mapping
209
- inplace: Whether to modify model in-place
210
211
Returns:
212
Quantized model
213
"""
214
215
def quantization.quantize(model, run_fn, run_args, mapping=None, inplace=False):
216
"""Post-training static quantization."""
217
218
def quantization.prepare(model, inplace=False, allow_list=None, observer_non_leaf_module_list=None, prepare_custom_config_dict=None):
219
"""Prepare a copy of the model for calibration by attaching observers (qconfig must be set beforehand)."""
220
221
def quantization.convert(model, mapping=None, inplace=False, remove_qconfig=True, convert_custom_config_dict=None):
222
"""Convert prepared model to quantized version."""
223
224
def quantization.prepare_qat(model, mapping=None, inplace=False):
225
"""Prepare model for quantization aware training."""
226
227
class quantization.QuantStub(nn.Module):
228
"""Quantization stub for marking quantization points."""
229
def __init__(self, qconfig=None): ...
230
def forward(self, x): ...
231
232
class quantization.DeQuantStub(nn.Module):
233
"""Dequantization stub for marking dequantization points."""
234
def __init__(self): ...
235
def forward(self, x): ...
236
237
class quantization.QConfig:
238
"""Quantization configuration."""
239
def __init__(self, activation, weight): ...
240
241
def quantization.get_default_qconfig(backend='fbgemm'):
242
"""Get default quantization configuration."""
243
244
def quantization.get_default_qat_qconfig(backend='fbgemm'):
245
"""Get default QAT quantization configuration."""
246
247
class quantization.FakeQuantize(nn.Module):
248
"""Fake quantization for QAT."""
249
def __init__(self, observer=MinMaxObserver, quant_min=0, quant_max=255, **observer_kwargs): ...
250
def forward(self, X): ...
251
def calculate_qparams(self): ...
252
```
253
254
### ONNX Export (torch.onnx)
255
256
Export PyTorch models to ONNX format for interoperability.
257
258
```python { .api }
259
def onnx.export(model, args, f, export_params=True, verbose=False, training=TrainingMode.EVAL,
260
input_names=None, output_names=None, operator_export_type=OperatorExportTypes.ONNX,
261
opset_version=None, do_constant_folding=True, dynamic_axes=None, keep_initializers_as_inputs=None,
262
custom_opsets=None, enable_onnx_checker=True, use_external_data_format=False):
263
"""
264
Export PyTorch model to ONNX format.
265
266
Parameters:
267
- model: PyTorch model to export
268
- args: Model input arguments
269
- f: File path or file-like object to save to
270
- export_params: Whether to export parameters
271
- verbose: Enable verbose output
272
- training: Training mode (EVAL, TRAINING, PRESERVE)
273
- input_names: Names for input nodes
274
- output_names: Names for output nodes
275
- opset_version: ONNX opset version
276
- dynamic_axes: Dynamic input/output axes
277
- custom_opsets: Custom operator sets
278
"""
279
280
def onnx.dynamo_export(model, *model_args, export_options=None, **model_kwargs) -> ONNXProgram:
281
"""Export using torch.export and Dynamo."""
282
283
class onnx.ONNXProgram:
284
"""ONNX program representation."""
285
def save(self, destination): ...
286
def model_proto(self): ...
287
288
def onnx.load(f) -> ModelProto:
289
"""Load ONNX model."""
290
291
def onnx.save(model, f, export_params=True):
292
"""Save ONNX model to file."""
293
294
class onnx.TrainingMode(Enum):
295
"""Training mode for ONNX export."""
296
EVAL = 0
297
TRAINING = 1
298
PRESERVE = 2
299
300
class onnx.OperatorExportTypes(Enum):
301
"""Operator export types."""
302
ONNX = 0
303
ONNX_ATEN = 1
304
ONNX_ATEN_FALLBACK = 2
305
```
306
307
### Mobile Deployment (torch.utils.mobile_optimizer)
308
309
Optimization utilities for mobile deployment.
310
311
```python { .api }
312
def utils.mobile_optimizer.optimize_for_mobile(script_module, optimization_blocklist=None, preserved_methods=None, backend='CPU'):
313
"""
314
Optimize TorchScript module for mobile deployment.
315
316
Parameters:
317
- script_module: TorchScript module to optimize
318
- optimization_blocklist: Operations to exclude from optimization
319
- preserved_methods: Methods to preserve during optimization
320
- backend: Target backend ('CPU', 'Vulkan', 'Metal')
321
322
Returns:
323
Optimized TorchScript module
324
"""
325
326
class utils.mobile_optimizer.LiteScriptModule:
327
"""Lightweight script module for mobile."""
328
def forward(self, *args): ...
329
def get_debug_info(self): ...
330
```
331
332
### TensorRT Integration
333
334
NVIDIA TensorRT integration for GPU inference optimization.
335
336
```python { .api }
337
def tensorrt.compile(model, inputs, enabled_precisions={torch.float}, workspace_size=1 << 22,
338
min_block_size=3, torch_executed_ops=None, torch_executed_modules=None):
339
"""
340
Compile model with TensorRT.
341
342
Parameters:
343
- model: PyTorch model to compile
344
- inputs: Example inputs for compilation
345
- enabled_precisions: Allowed precision types
346
- workspace_size: TensorRT workspace size
347
- min_block_size: Minimum block size for TensorRT subgraphs
348
349
Returns:
350
TensorRT compiled model
351
"""
352
```
353
354
### Automatic Mixed Precision (torch.amp)
355
356
Automatic mixed precision training for performance and memory optimization.
357
358
```python { .api }
359
class amp.GradScaler:
360
"""Gradient scaler for mixed precision training."""
361
def __init__(self, init_scale=2**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True):
362
"""
363
Parameters:
364
- init_scale: Initial scale factor
365
- growth_factor: Scale growth factor
366
- backoff_factor: Scale reduction factor
367
- growth_interval: Steps between scale increases
368
- enabled: Whether scaler is enabled
369
"""
370
371
def scale(self, outputs): ...
372
def step(self, optimizer): ...
373
def update(self): ...
374
def unscale_(self, optimizer): ...
375
def get_scale(self): ...
376
def get_growth_factor(self): ...
377
def set_growth_factor(self, new_factor): ...
378
def get_backoff_factor(self): ...
379
def set_backoff_factor(self, new_factor): ...
380
def get_growth_interval(self): ...
381
def set_growth_interval(self, new_interval): ...
382
def is_enabled(self): ...
383
def state_dict(self): ...
384
def load_state_dict(self, state_dict): ...
385
386
def amp.autocast(device_type='cuda', dtype=None, enabled=True, cache_enabled=None):
387
"""
388
Context manager for automatic mixed precision.
389
390
Parameters:
391
- device_type: Device type ('cuda', 'cpu', 'xpu')
392
- dtype: Target dtype (torch.float16, torch.bfloat16)
393
- enabled: Whether autocast is enabled
394
- cache_enabled: Whether the weight cache inside autocast is enabled
395
"""
396
```
397
398
### Model Optimization and Pruning (torch.ao)
399
400
Advanced optimization techniques including pruning and sparsity.
401
402
```python { .api }
403
def ao.pruning.prune_low_magnitude(model, amount, importance_scores=None, structured=False, dim=None):
404
"""
405
Prune model by removing low magnitude weights.
406
407
Parameters:
408
- model: Model to prune
409
- amount: Fraction of weights to prune
410
- importance_scores: Custom importance scores
411
- structured: Whether to use structured pruning
412
- dim: Dimension for structured pruning
413
414
Returns:
415
Pruned model
416
"""
417
418
class ao.pruning.WeightNormSparsifier:
419
"""Weight norm based sparsifier."""
420
def __init__(self, sparsity_level=0.5): ...
421
def update_mask(self, module, tensor_name, **kwargs): ...
422
423
class ao.quantization.QConfigMapping:
424
"""Quantization configuration mapping."""
425
def set_global(self, qconfig): ...
426
def set_object_type(self, object_type, qconfig): ...
427
def set_module_name(self, module_name, qconfig): ...
428
429
def ao.quantization.get_default_qconfig_mapping(backend='x86'):
430
"""Get default quantization configuration mapping."""
431
432
class ao.quantization.FusedMovingAvgObsFakeQuantize(nn.Module):
433
"""Fused moving average observer fake quantize."""
434
def __init__(self, observer=MovingAverageMinMaxObserver, **observer_kwargs): ...
435
```
436
437
## Usage Examples
438
439
### TorchScript Compilation
440
441
```python
442
import torch
import torch.nn as nn


# Define model
class SimpleModel(nn.Module):
    """Linear(10 -> 5) followed by ReLU; small enough to script and trace quickly."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x):
        return torch.relu(self.linear(x))


model = SimpleModel()
# Switch to eval mode before compiling for inference.
model.eval()

# Script compilation
scripted_model = torch.jit.script(model)
print(scripted_model.code)

# Trace compilation (records the operations run on the example input)
example_input = torch.randn(1, 10)
traced_model = torch.jit.trace(model, example_input)

# Save/load (writes model_scripted.pt into the current working directory)
torch.jit.save(scripted_model, 'model_scripted.pt')
loaded_model = torch.jit.load('model_scripted.pt')

# Optimization for inference
optimized_model = torch.jit.optimize_for_inference(scripted_model)

print("TorchScript compilation completed")
473
```
474
475
### Model Export and Deployment
476
477
```python
478
import torch
479
import torch.nn as nn
480
from torch.export import export
481
482
# Define model
483
class ExportModel(nn.Module):
484
def __init__(self):
485
super().__init__()
486
self.conv = nn.Conv2d(3, 16, 3, padding=1)
487
self.pool = nn.AdaptiveAvgPool2d((1, 1))
488
self.fc = nn.Linear(16, 10)
489
490
def forward(self, x):
491
x = torch.relu(self.conv(x))
492
x = self.pool(x)
493
x = x.flatten(1)
494
return self.fc(x)
495
496
model = ExportModel()
497
example_input = torch.randn(1, 3, 32, 32)
498
499
# Export to ExportedProgram
500
exported_program = export(model, (example_input,))
501
502
# Save exported program
503
torch.export.save(exported_program, 'exported_model.pt2')
504
505
# Load exported program
506
loaded_program = torch.export.load('exported_model.pt2')
507
508
# Use exported program
509
output = loaded_program.module()(example_input)
510
print(f"Export completed, output shape: {output.shape}")
511
```
512
513
### Torch Compile Usage
514
515
```python
516
import torch
517
import torch.nn as nn
518
519
# Define model
520
model = nn.Sequential(
521
nn.Linear(100, 200),
522
nn.ReLU(),
523
nn.Linear(200, 100),
524
nn.ReLU(),
525
nn.Linear(100, 10)
526
)
527
528
# Compile with different modes
529
default_compiled = torch.compile(model)
530
fast_compiled = torch.compile(model, mode="reduce-overhead")
531
optimal_compiled = torch.compile(model, mode="max-autotune")
532
533
# Use as decorator
534
@torch.compile
535
def custom_function(x, y):
536
return x.matmul(y) + x.sum()
537
538
# Example usage
539
x = torch.randn(32, 100)
540
y = torch.randn(100, 50)
541
542
# Compiled function
543
result = custom_function(x, y)
544
545
# Compiled model
546
output = optimal_compiled(x)
547
548
print(f"Torch compile completed, output shape: {output.shape}")
549
```
550
551
### Quantization Example
552
553
```python
554
import io

import torch
import torch.nn as nn
import torch.quantization as quant


# Define model
class QuantModel(nn.Module):
    """Small CNN whose input and output are wrapped in quant/dequant stubs."""

    def __init__(self):
        super().__init__()
        self.quant = quant.QuantStub()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, 10)
        self.dequant = quant.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))
        x = self.pool(x)
        x = x.flatten(1)
        x = self.fc(x)
        x = self.dequant(x)
        return x


def serialized_size_bytes(module):
    """Return the number of bytes ``torch.save`` writes for ``module.state_dict()``."""
    buffer = io.BytesIO()
    torch.save(module.state_dict(), buffer)
    return buffer.getbuffer().nbytes


model = QuantModel()
model.eval()

# Dynamic quantization (only nn.Linear layers are targeted here)
quantized_model = quant.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

# Post-training static quantization
model.qconfig = quant.get_default_qconfig('fbgemm')
prepared_model = quant.prepare(model)

# Calibration (example data)
for _ in range(10):
    calibration_data = torch.randn(1, 3, 32, 32)
    prepared_model(calibration_data)

# Convert to quantized model
quantized_static_model = quant.convert(prepared_model)

# Compare serialized sizes rather than parameter counts: quantized modules keep
# their packed weights outside `parameters()`, so counting parameters would
# under-report the quantized models.
original_size = serialized_size_bytes(model)
dynamic_size = serialized_size_bytes(quantized_model)
static_size = serialized_size_bytes(quantized_static_model)

print("Quantization completed")
print(f"Original model size: {original_size} bytes")
print(f"Dynamically quantized model size: {dynamic_size} bytes")
print(f"Statically quantized model size: {static_size} bytes")
604
```
605
606
### ONNX Export
607
608
```python
609
import torch
610
import torch.nn as nn
611
import torch.onnx
612
613
# Define model
614
class ONNXModel(nn.Module):
615
def __init__(self):
616
super().__init__()
617
self.backbone = nn.Sequential(
618
nn.Conv2d(3, 64, 7, stride=2, padding=3),
619
nn.BatchNorm2d(64),
620
nn.ReLU(inplace=True),
621
nn.AdaptiveAvgPool2d((1, 1)),
622
nn.Flatten(),
623
nn.Linear(64, 1000)
624
)
625
626
def forward(self, x):
627
return self.backbone(x)
628
629
model = ONNXModel()
630
model.eval()
631
632
# Example input
633
dummy_input = torch.randn(1, 3, 224, 224)
634
635
# Export to ONNX
636
torch.onnx.export(
637
model,
638
dummy_input,
639
"model.onnx",
640
export_params=True,
641
opset_version=11,
642
do_constant_folding=True,
643
input_names=['input'],
644
output_names=['output'],
645
dynamic_axes={
646
'input': {0: 'batch_size'},
647
'output': {0: 'batch_size'}
648
}
649
)
650
651
print("ONNX export completed")
652
```
653
654
### FX Graph Manipulation
655
656
```python
657
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fx as fx


# Define model
class FXModel(nn.Module):
    """Two conv+ReLU stages, global average pooling and a linear head."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, 10)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.pool(x)
        x = x.flatten(1)
        x = self.fc(x)
        return x


# Symbolic tracing
model = FXModel()
traced = fx.symbolic_trace(model)

# Print graph (graph.print_tabular() gives a table but needs the optional
# `tabulate` package, so the plain string form is used here)
print("Original graph:")
print(traced.graph)

# Graph manipulation - replace ReLU with GELU.
# `self.relu` is an nn.ReLU submodule, so it is traced as a `call_module` node
# whose target is the string "relu"; functional ReLU calls are matched as well.
submodules = dict(traced.named_modules())
# Iterate over a snapshot because nodes are erased inside the loop.
for node in list(traced.graph.nodes):
    is_relu_module = (
        node.op == "call_module"
        and isinstance(submodules.get(node.target), nn.ReLU)
    )
    is_relu_function = (
        node.op == "call_function" and node.target in (torch.relu, F.relu)
    )
    if not (is_relu_module or is_relu_function):
        continue
    with traced.graph.inserting_after(node):
        new_node = traced.graph.call_function(F.gelu, args=(node.args[0],))
    node.replace_all_uses_with(new_node)
    traced.graph.erase_node(node)

# Validate the edited graph, then regenerate forward() from it
traced.graph.lint()
traced.recompile()

print("\nModified graph:")
print(traced.graph)

# Test modified model
test_input = torch.randn(1, 3, 32, 32)
output = traced(test_input)
print(f"FX transformation completed, output shape: {output.shape}")
705
```
706
707
### Mixed Precision Training
708
709
```python
710
import torch
711
import torch.nn as nn
712
import torch.optim as optim
713
from torch.cuda.amp import autocast, GradScaler
714
715
# Define model and training setup
716
model = nn.Sequential(
717
nn.Linear(1000, 500),
718
nn.ReLU(),
719
nn.Linear(500, 100),
720
nn.ReLU(),
721
nn.Linear(100, 10)
722
).cuda()
723
724
optimizer = optim.Adam(model.parameters(), lr=0.001)
725
criterion = nn.CrossEntropyLoss()
726
scaler = GradScaler()
727
728
# Training loop with mixed precision
729
model.train()
730
for epoch in range(5):
731
for batch_idx in range(100): # Simulate 100 batches
732
# Generate dummy data
733
data = torch.randn(32, 1000).cuda()
734
targets = torch.randint(0, 10, (32,)).cuda()
735
736
optimizer.zero_grad()
737
738
# Forward pass with autocast
739
with autocast():
740
outputs = model(data)
741
loss = criterion(outputs, targets)
742
743
# Backward pass with gradient scaling
744
scaler.scale(loss).backward()
745
scaler.step(optimizer)
746
scaler.update()
747
748
if batch_idx % 25 == 0:
749
print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}, Scale: {scaler.get_scale()}")
750
751
print("Mixed precision training completed")
752
```