or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

advanced-features.md devices-distributed.md index.md mathematical-functions.md neural-networks.md tensor-operations.md training.md

docs/advanced-features.md

0

# Advanced Features

1

2

JIT compilation, model export, graph transformations, quantization, and deployment utilities for optimizing and deploying PyTorch models in production environments.

3

4

## Capabilities

5

6

### JIT Compilation (torch.jit)

7

8

TorchScript compilation for model optimization and deployment.

9

10

```python { .api }

11

def jit.script(obj, optimize=None, _frames_up=0, _rcb=None):

12

"""

13

Compile Python code to TorchScript.

14

15

Parameters:

16

- obj: Function, method, or class to compile

17

- optimize: Whether to apply optimizations

18

19

Returns:

20

ScriptModule or ScriptFunction

21

"""

22

23

def jit.trace(func, example_inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-5, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=None):

24

"""

25

Trace function execution to create TorchScript.

26

27

Parameters:

28

- func: Function or module to trace

29

- example_inputs: Example inputs for tracing

30

- optimize: Whether to apply optimizations

31

- check_trace: Whether to verify trace correctness

32

- strict: Whether to run the tracer in strict mode; disable only to allow mutable container types (list/dict) in the traced outputs

33

34

Returns:

35

TracedModule or function

36

"""

37

38

def jit.load(f, map_location=None, _extra_files=None):

39

"""Load TorchScript model from file."""

40

41

def jit.save(m, f, _extra_files=None):

42

"""Save TorchScript model to file."""

43

44

class jit.ScriptModule(nn.Module):

45

"""TorchScript compiled module."""

46

def save(self, f, _extra_files=None): ...

47

def code(self) -> str: ...

48

def graph(self): ...

49

def code_with_constants(self) -> Tuple[str, List[Tensor]]: ...

50

51

def jit.freeze(mod, preserved_attrs=None, optimize_numerics=True):

52

"""Freeze TorchScript module for inference."""

53

54

def jit.optimize_for_inference(mod, other_methods=None):

55

"""Optimize TorchScript module for inference."""

56

57

def jit.enable_onednn_fusion(enabled: bool):

58

"""Enable/disable OneDNN fusion optimization."""

59

60

def jit.set_fusion_strategy(strategy: List[Tuple[str, bool]]):

61

"""Set fusion strategy for optimization."""

62

```

63

64

### Model Export (torch.export)

65

66

Export PyTorch models for deployment and optimization.

67

68

```python { .api }

69

def export.export(mod: nn.Module, args, kwargs=None, *, dynamic_shapes=None, strict=True) -> ExportedProgram:

70

"""

71

Trace a PyTorch module ahead of time with the example inputs and capture it as an ExportedProgram.

72

73

Parameters:

74

- mod: Module to export

75

- args: Example arguments

76

- kwargs: Example keyword arguments

77

- dynamic_shapes: Dynamic shape specifications

78

- strict: Whether to enforce strict export

79

80

Returns:

81

ExportedProgram

82

"""

83

84

class export.ExportedProgram:

85

"""Exported PyTorch program."""

86

def module(self) -> nn.Module: ...

87

def graph_module(self): ...

88

def graph_signature(self): ...

89

def call_spec(self): ...

90

def verifier(self): ...

91

def state_dict(self) -> Dict[str, Any]: ...

92

def named_parameters(self): ...

93

def named_buffers(self): ...

94

95

def export.save(ep: ExportedProgram, f) -> None:

96

"""Save exported program to file."""

97

98

def export.load(f) -> ExportedProgram:

99

"""Load exported program from file."""

100

```

101

102

### Model Compilation (torch.compile)

103

104

Compile PyTorch models for performance optimization.

105

106

```python { .api }

107

def compile(model=None, *, fullgraph=False, dynamic=None, backend="inductor", mode=None, options=None, disable=False):

108

"""

109

Compile PyTorch model for optimization.

110

111

Parameters:

112

- model: Model to compile (or use as decorator)

113

- fullgraph: Whether to require the whole model to be captured as a single graph (raises an error on graph breaks instead of falling back)

114

- dynamic: Enable dynamic shapes

115

- backend: Compilation backend ("inductor", "aot_eager", etc.)

116

- mode: Compilation mode ("default", "reduce-overhead", "max-autotune")

117

- options: Backend-specific options

118

- disable: Disable compilation

119

120

Returns:

121

Compiled model

122

"""

123

124

@compile

125

def compiled_function(x):

126

"""Example of function compilation."""

127

return x * 2 + 1

128

129

# Alternative usage

130

compiled_model = torch.compile(model, mode="max-autotune")

131

```

132

133

### Graph Transformations (torch.fx)

134

135

Symbolic tracing and graph manipulation for model analysis and optimization.

136

137

```python { .api }

138

class fx.GraphModule(nn.Module):

139

"""Module with FX graph representation."""

140

def __init__(self, root, graph, class_name='GraphModule'): ...

141

def recompile(self): ...

142

def code(self) -> str: ...

143

def graph(self): ...

144

def print_readable(self, print_output=True): ...

145

146

def fx.symbolic_trace(root, concrete_args=None, meta_args=None, _force_outplace=False) -> GraphModule:

147

"""

148

Symbolically trace PyTorch module.

149

150

Parameters:

151

- root: Module or function to trace

152

- concrete_args: Arguments to keep concrete

153

- meta_args: Meta tensor arguments

154

155

Returns:

156

GraphModule with traced computation graph

157

"""

158

159

class fx.Tracer:

160

"""Tracer for symbolic execution."""

161

def trace(self, root, concrete_args=None): ...

162

def call_module(self, m, forward, args, kwargs): ...

163

def call_function(self, target, args, kwargs): ...

164

def call_method(self, target, args, kwargs): ...

165

166

class fx.Graph:

167

"""Computational graph representation."""

168

def nodes(self): ...

169

def create_node(self, op, target, args=None, kwargs=None, name=None, type_expr=None): ...

170

def erase_node(self, to_erase): ...

171

def inserting_before(self, n): ...

172

def inserting_after(self, n): ...

173

def lint(self): ...

174

def print_tabular(self): ...

175

176

class fx.Node:

177

"""Node in FX graph."""

178

def replace_all_uses_with(self, replace_with): ...

179

def replace_input_with(self, old_input, new_input): ...

180

def append(self, x): ...

181

def prepend(self, x): ...

182

183

def fx.replace_pattern(gm: GraphModule, pattern, replacement) -> List[Match]:

184

"""Replace patterns in graph."""

185

186

class fx.Interpreter:

187

"""Base class for FX graph interpreters."""

188

def run(self, *args, **kwargs): ...

189

def run_node(self, n): ...

190

def call_function(self, target, args, kwargs): ...

191

def call_method(self, target, args, kwargs): ...

192

def call_module(self, target, args, kwargs): ...

193

```

194

195

### Quantization (torch.quantization)

196

197

Model quantization for efficient deployment.

198

199

```python { .api }

200

def quantization.quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8, mapping=None, inplace=False, remove_qconfig=True):

201

"""

202

Dynamic quantization of model.

203

204

Parameters:

205

- model: Model to quantize

206

- qconfig_spec: Quantization configuration

207

- dtype: Target quantized data type

208

- mapping: Custom op mapping

209

- inplace: Whether to modify model in-place

210

211

Returns:

212

Quantized model

213

"""

214

215

def quantization.quantize(model, run_fn, run_args, mapping=None, inplace=False):

216

"""Post-training static quantization."""

217

218

def quantization.prepare(model, inplace=False, allow_list=None, observer_non_leaf_module_list=None, prepare_custom_config_dict=None):

219

"""Prepare model for calibration by attaching observers (post-training static quantization)."""

220

221

def quantization.convert(model, mapping=None, inplace=False, remove_qconfig=True, convert_custom_config_dict=None):

222

"""Convert prepared model to quantized version."""

223

224

def quantization.prepare_qat(model, mapping=None, inplace=False):

225

"""Prepare model for quantization aware training."""

226

227

class quantization.QuantStub(nn.Module):

228

"""Quantization stub for marking quantization points."""

229

def __init__(self, qconfig=None): ...

230

def forward(self, x): ...

231

232

class quantization.DeQuantStub(nn.Module):

233

"""Dequantization stub for marking dequantization points."""

234

def __init__(self): ...

235

def forward(self, x): ...

236

237

class quantization.QConfig:

238

"""Quantization configuration."""

239

def __init__(self, activation, weight): ...

240

241

def quantization.get_default_qconfig(backend='fbgemm'):

242

"""Get default quantization configuration."""

243

244

def quantization.get_default_qat_qconfig(backend='fbgemm'):

245

"""Get default QAT quantization configuration."""

246

247

class quantization.FakeQuantize(nn.Module):

248

"""Fake quantization for QAT."""

249

def __init__(self, observer=MinMaxObserver, quant_min=0, quant_max=255, **observer_kwargs): ...

250

def forward(self, X): ...

251

def calculate_qparams(self): ...

252

```

253

254

### ONNX Export (torch.onnx)

255

256

Export PyTorch models to ONNX format for interoperability.

257

258

```python { .api }

259

def onnx.export(model, args, f, export_params=True, verbose=False, training=TrainingMode.EVAL,

260

input_names=None, output_names=None, operator_export_type=OperatorExportTypes.ONNX,

261

opset_version=None, do_constant_folding=True, dynamic_axes=None, keep_initializers_as_inputs=None,

262

custom_opsets=None, enable_onnx_checker=True, use_external_data_format=False):

263

"""

264

Export PyTorch model to ONNX format.

265

266

Parameters:

267

- model: PyTorch model to export

268

- args: Model input arguments

269

- f: File path or file-like object to save to

270

- export_params: Whether to export parameters

271

- verbose: Enable verbose output

272

- training: Training mode (EVAL, TRAINING, PRESERVE)

273

- input_names: Names for input nodes

274

- output_names: Names for output nodes

275

- opset_version: ONNX opset version

276

- dynamic_axes: Dynamic input/output axes

277

- custom_opsets: Custom operator sets

278

"""

279

280

def onnx.dynamo_export(model, *model_args, export_options=None, **model_kwargs) -> ONNXProgram:

281

"""Export using torch.export and Dynamo."""

282

283

class onnx.ONNXProgram:

284

"""ONNX program representation."""

285

def save(self, destination): ...

286

def model_proto(self): ...

287

288

def onnx.load(f) -> ModelProto:

289

"""Load ONNX model."""

290

291

def onnx.save(model, f, export_params=True):

292

"""Save ONNX model to file."""

293

294

class onnx.TrainingMode(Enum):

295

"""Training mode for ONNX export."""

296

EVAL = 0

297

TRAINING = 1

298

PRESERVE = 2

299

300

class onnx.OperatorExportTypes(Enum):

301

"""Operator export types."""

302

ONNX = 0

303

ONNX_ATEN = 1

304

ONNX_ATEN_FALLBACK = 2

305

```

306

307

### Mobile Deployment (torch.utils.mobile_optimizer)

308

309

Optimization utilities for mobile deployment.

310

311

```python { .api }

312

def utils.mobile_optimizer.optimize_for_mobile(script_module, optimization_blocklist=None, preserved_methods=None, backend='CPU'):

313

"""

314

Optimize TorchScript module for mobile deployment.

315

316

Parameters:

317

- script_module: TorchScript module to optimize

318

- optimization_blocklist: Operations to exclude from optimization

319

- preserved_methods: Methods to preserve during optimization

320

- backend: Target backend ('CPU', 'Vulkan', 'Metal')

321

322

Returns:

323

Optimized TorchScript module

324

"""

325

326

class utils.mobile_optimizer.LiteScriptModule:

327

"""Lightweight script module for mobile."""

328

def forward(self, *args): ...

329

def get_debug_info(self): ...

330

```

331

332

### TensorRT Integration

333

334

NVIDIA TensorRT integration for GPU inference optimization.

335

336

```python { .api }

337

def tensorrt.compile(model, inputs, enabled_precisions={torch.float}, workspace_size=1 << 22,

338

min_block_size=3, torch_executed_ops=None, torch_executed_modules=None):

339

"""

340

Compile model with TensorRT.

341

342

Parameters:

343

- model: PyTorch model to compile

344

- inputs: Example inputs for compilation

345

- enabled_precisions: Allowed precision types

346

- workspace_size: TensorRT workspace size

347

- min_block_size: Minimum block size for TensorRT subgraphs

348

349

Returns:

350

TensorRT compiled model

351

"""

352

```

353

354

### Automatic Mixed Precision (torch.amp)

355

356

Automatic mixed precision training for performance and memory optimization.

357

358

```python { .api }

359

class amp.GradScaler:

360

"""Gradient scaler for mixed precision training."""

361

def __init__(self, init_scale=2**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True):

362

"""

363

Parameters:

364

- init_scale: Initial scale factor

365

- growth_factor: Scale growth factor

366

- backoff_factor: Scale reduction factor

367

- growth_interval: Steps between scale increases

368

- enabled: Whether scaler is enabled

369

"""

370

371

def scale(self, outputs): ...

372

def step(self, optimizer): ...

373

def update(self): ...

374

def unscale_(self, optimizer): ...

375

def get_scale(self): ...

376

def get_growth_factor(self): ...

377

def set_growth_factor(self, new_factor): ...

378

def get_backoff_factor(self): ...

379

def set_backoff_factor(self, new_factor): ...

380

def get_growth_interval(self): ...

381

def set_growth_interval(self, new_interval): ...

382

def is_enabled(self): ...

383

def state_dict(self): ...

384

def load_state_dict(self, state_dict): ...

385

386

def amp.autocast(device_type='cuda', dtype=None, enabled=True, cache_enabled=None):

387

"""

388

Context manager for automatic mixed precision.

389

390

Parameters:

391

- device_type: Device type ('cuda', 'cpu', 'xpu')

392

- dtype: Target dtype (torch.float16, torch.bfloat16)

393

- enabled: Whether autocast is enabled

394

- cache_enabled: Whether the weight cache inside autocast is enabled

395

"""

396

```

397

398

### Model Optimization and Pruning (torch.ao)

399

400

Advanced optimization techniques including pruning and sparsity.

401

402

```python { .api }

403

def ao.pruning.prune_low_magnitude(model, amount, importance_scores=None, structured=False, dim=None):

404

"""

405

Prune model by removing low magnitude weights.

406

407

Parameters:

408

- model: Model to prune

409

- amount: Fraction of weights to prune

410

- importance_scores: Custom importance scores

411

- structured: Whether to use structured pruning

412

- dim: Dimension for structured pruning

413

414

Returns:

415

Pruned model

416

"""

417

418

class ao.pruning.WeightNormSparsifier:

419

"""Weight norm based sparsifier."""

420

def __init__(self, sparsity_level=0.5): ...

421

def update_mask(self, module, tensor_name, **kwargs): ...

422

423

class ao.quantization.QConfigMapping:

424

"""Quantization configuration mapping."""

425

def set_global(self, qconfig): ...

426

def set_object_type(self, object_type, qconfig): ...

427

def set_module_name(self, module_name, qconfig): ...

428

429

def ao.quantization.get_default_qconfig_mapping(backend='x86'):

430

"""Get default quantization configuration mapping."""

431

432

class ao.quantization.FusedMovingAvgObsFakeQuantize(nn.Module):

433

"""Fused moving average observer fake quantize."""

434

def __init__(self, observer=MovingAverageMinMaxObserver, **observer_kwargs): ...

435

```

436

437

## Usage Examples

438

439

### TorchScript Compilation

440

441

```python

442

import torch

443

import torch.nn as nn

444

445

# Define model

446

class SimpleModel(nn.Module):

447

def __init__(self):

448

super(SimpleModel, self).__init__()

449

self.linear = nn.Linear(10, 5)

450

451

def forward(self, x):

452

return torch.relu(self.linear(x))

453

454

model = SimpleModel()

455

model.eval()

456

457

# Script compilation

458

scripted_model = torch.jit.script(model)

459

print(scripted_model.code)

460

461

# Trace compilation

462

example_input = torch.randn(1, 10)

463

traced_model = torch.jit.trace(model, example_input)

464

465

# Save/load

466

torch.jit.save(scripted_model, 'model_scripted.pt')

467

loaded_model = torch.jit.load('model_scripted.pt')

468

469

# Optimization for inference

470

optimized_model = torch.jit.optimize_for_inference(scripted_model)

471

472

print("TorchScript compilation completed")

473

```

474

475

### Model Export and Deployment

476

477

```python

478

import torch

479

import torch.nn as nn

480

from torch.export import export

481

482

# Define model

483

class ExportModel(nn.Module):

484

def __init__(self):

485

super().__init__()

486

self.conv = nn.Conv2d(3, 16, 3, padding=1)

487

self.pool = nn.AdaptiveAvgPool2d((1, 1))

488

self.fc = nn.Linear(16, 10)

489

490

def forward(self, x):

491

x = torch.relu(self.conv(x))

492

x = self.pool(x)

493

x = x.flatten(1)

494

return self.fc(x)

495

496

model = ExportModel()

497

example_input = torch.randn(1, 3, 32, 32)

498

499

# Export to ExportedProgram

500

exported_program = export(model, (example_input,))

501

502

# Save exported program

503

torch.export.save(exported_program, 'exported_model.pt2')

504

505

# Load exported program

506

loaded_program = torch.export.load('exported_model.pt2')

507

508

# Use exported program

509

output = loaded_program.module()(example_input)

510

print(f"Export completed, output shape: {output.shape}")

511

```

512

513

### Torch Compile Usage

514

515

```python

516

import torch

517

import torch.nn as nn

518

519

# Define model

520

model = nn.Sequential(

521

nn.Linear(100, 200),

522

nn.ReLU(),

523

nn.Linear(200, 100),

524

nn.ReLU(),

525

nn.Linear(100, 10)

526

)

527

528

# Compile with different modes

529

default_compiled = torch.compile(model)

530

fast_compiled = torch.compile(model, mode="reduce-overhead")

531

optimal_compiled = torch.compile(model, mode="max-autotune")

532

533

# Use as decorator

534

@torch.compile

535

def custom_function(x, y):

536

return x.matmul(y) + x.sum()

537

538

# Example usage

539

x = torch.randn(32, 100)

540

y = torch.randn(100, 50)

541

542

# Compiled function

543

result = custom_function(x, y)

544

545

# Compiled model

546

output = optimal_compiled(x)

547

548

print(f"Torch compile completed, output shape: {output.shape}")

549

```

550

551

### Quantization Example

552

553

```python

554

import torch

555

import torch.nn as nn

556

import torch.quantization as quant

557

558

# Define model

559

class QuantModel(nn.Module):

560

def __init__(self):

561

super().__init__()

562

self.quant = quant.QuantStub()

563

self.conv1 = nn.Conv2d(3, 32, 3, padding=1)

564

self.relu1 = nn.ReLU()

565

self.conv2 = nn.Conv2d(32, 64, 3, padding=1)

566

self.relu2 = nn.ReLU()

567

self.pool = nn.AdaptiveAvgPool2d((1, 1))

568

self.fc = nn.Linear(64, 10)

569

self.dequant = quant.DeQuantStub()

570

571

def forward(self, x):

572

x = self.quant(x)

573

x = self.relu1(self.conv1(x))

574

x = self.relu2(self.conv2(x))

575

x = self.pool(x)

576

x = x.flatten(1)

577

x = self.fc(x)

578

x = self.dequant(x)

579

return x

580

581

model = QuantModel()

582

model.eval()

583

584

# Dynamic quantization

585

quantized_model = quant.quantize_dynamic(

586

model, {nn.Linear}, dtype=torch.qint8

587

)

588

589

# Post-training static quantization

590

model.qconfig = quant.get_default_qconfig('fbgemm')

591

prepared_model = quant.prepare(model)

592

593

# Calibration (example data)

594

for _ in range(10):

595

calibration_data = torch.randn(1, 3, 32, 32)

596

prepared_model(calibration_data)

597

598

# Convert to quantized model

599

quantized_static_model = quant.convert(prepared_model)

600

601

print("Quantization completed")

602

print(f"Original model size: {sum(p.numel() for p in model.parameters())}")

603

print(f"Quantized model parameters: {sum(p.numel() for p in quantized_model.parameters())}")

604

```

605

606

### ONNX Export

607

608

```python

609

import torch

610

import torch.nn as nn

611

import torch.onnx

612

613

# Define model

614

class ONNXModel(nn.Module):

615

def __init__(self):

616

super().__init__()

617

self.backbone = nn.Sequential(

618

nn.Conv2d(3, 64, 7, stride=2, padding=3),

619

nn.BatchNorm2d(64),

620

nn.ReLU(inplace=True),

621

nn.AdaptiveAvgPool2d((1, 1)),

622

nn.Flatten(),

623

nn.Linear(64, 1000)

624

)

625

626

def forward(self, x):

627

return self.backbone(x)

628

629

model = ONNXModel()

630

model.eval()

631

632

# Example input

633

dummy_input = torch.randn(1, 3, 224, 224)

634

635

# Export to ONNX

636

torch.onnx.export(

637

model,

638

dummy_input,

639

"model.onnx",

640

export_params=True,

641

opset_version=11,

642

do_constant_folding=True,

643

input_names=['input'],

644

output_names=['output'],

645

dynamic_axes={

646

'input': {0: 'batch_size'},

647

'output': {0: 'batch_size'}

648

}

649

)

650

651

print("ONNX export completed")

652

```

653

654

### FX Graph Manipulation

655

656

```python

657

import torch

658

import torch.nn as nn

659

import torch.fx as fx

660

661

# Define model

662

class FXModel(nn.Module):

663

def __init__(self):

664

super().__init__()

665

self.conv1 = nn.Conv2d(3, 32, 3)

666

self.conv2 = nn.Conv2d(32, 64, 3)

667

self.relu = nn.ReLU()

668

self.pool = nn.AdaptiveAvgPool2d((1, 1))

669

self.fc = nn.Linear(64, 10)

670

671

def forward(self, x):

672

x = self.relu(self.conv1(x))

673

x = self.relu(self.conv2(x))

674

x = self.pool(x)

675

x = x.flatten(1)

676

x = self.fc(x)

677

return x

678

679

# Symbolic tracing

680

model = FXModel()

681

traced = fx.symbolic_trace(model)

682

683

# Print graph

684

print("Original graph:")

685

traced.graph.print_tabular()

686

687

# Graph manipulation - replace ReLU with GELU

688

for node in traced.graph.nodes:

689

if node.op == "call_module" and isinstance(traced.get_submodule(node.target), nn.ReLU):

690

with traced.graph.inserting_after(node):

691

new_node = traced.graph.call_function(torch.nn.functional.gelu, args=(node.args[0],))

692

node.replace_all_uses_with(new_node)

693

traced.graph.erase_node(node)

694

695

# Recompile

696

traced.recompile()

697

698

print("\nModified graph:")

699

traced.graph.print_tabular()

700

701

# Test modified model

702

test_input = torch.randn(1, 3, 32, 32)

703

output = traced(test_input)

704

print(f"FX transformation completed, output shape: {output.shape}")

705

```

706

707

### Mixed Precision Training

708

709

```python

710

import torch

711

import torch.nn as nn

712

import torch.optim as optim

713

from torch.cuda.amp import autocast, GradScaler

714

715

# Define model and training setup

716

model = nn.Sequential(

717

nn.Linear(1000, 500),

718

nn.ReLU(),

719

nn.Linear(500, 100),

720

nn.ReLU(),

721

nn.Linear(100, 10)

722

).cuda()

723

724

optimizer = optim.Adam(model.parameters(), lr=0.001)

725

criterion = nn.CrossEntropyLoss()

726

scaler = GradScaler()

727

728

# Training loop with mixed precision

729

model.train()

730

for epoch in range(5):

731

for batch_idx in range(100): # Simulate 100 batches

732

# Generate dummy data

733

data = torch.randn(32, 1000).cuda()

734

targets = torch.randint(0, 10, (32,)).cuda()

735

736

optimizer.zero_grad()

737

738

# Forward pass with autocast

739

with autocast():

740

outputs = model(data)

741

loss = criterion(outputs, targets)

742

743

# Backward pass with gradient scaling

744

scaler.scale(loss).backward()

745

scaler.step(optimizer)

746

scaler.update()

747

748

if batch_idx % 25 == 0:

749

print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}, Scale: {scaler.get_scale()}")

750

751

print("Mixed precision training completed")

752

```