or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

datasets.md index.md io.md models.md ops.md transforms.md tv_tensors.md utils.md

docs/models.md

0

# Models

1

2

TorchVision provides pre-trained neural network models for a range of computer vision tasks: image classification, object detection, instance segmentation, semantic segmentation, keypoint detection, video analysis, and optical flow estimation. Every model can be used in either training or evaluation mode, and each model builder can optionally load pre-trained weights.

3

4

## Capabilities

5

6

### Model Management API

7

8

High-level API for discovering and loading models with configuration.

9

10

```python { .api }

11

def get_model(name: str, **config) -> torch.nn.Module:

12

"""

13

Get model by name with configuration.

14

15

Args:

16

name (str): Model name

17

**config: Model-specific configuration parameters

18

19

Returns:

20

torch.nn.Module: Instantiated model

21

"""

22

23

def get_model_builder(name: str):

24

"""

25

Get model builder function by name.

26

27

Args:

28

name (str): Model name

29

30

Returns:

31

Callable: Model builder function

32

"""

33

34

def get_model_weights(name: str):

35

"""

36

Get available weights for a model.

37

38

Args:

39

name (str): Model name

40

41

Returns:

42

Type[WeightsEnum]: Weights enum class associated with the model

43

"""

44

45

def get_weight(name: str):

46

"""

47

Get specific weight by name.

48

49

Args:

50

name (str): Full name of the weight enum entry, e.g. "ResNet50_Weights.IMAGENET1K_V1"

51

52

Returns:

53

WeightsEnum: Weights enum member identified by the given name

54

"""

55

56

def list_models() -> list[str]:

57

"""

58

List all available models.

59

60

Returns:

61

list[str]: List of model names

62

"""

63

64

class Weights:

65

"""Dataclass for model weights metadata."""

66

url: str

67

transforms: callable

68

meta: dict

69

70

class WeightsEnum:

71

"""Enum base class for model weights."""

72

```

73

74

### Classification Models

75

76

#### ResNet Family

77

78

Deep residual networks with skip connections for image classification.

79

80

```python { .api }

81

class ResNet(torch.nn.Module):

82

"""

83

ResNet architecture implementation.

84

85

Args:

86

block: Block type (BasicBlock or Bottleneck)

87

layers (list): Number of blocks per layer

88

num_classes (int): Number of classes for classification

89

zero_init_residual (bool): Zero-initialize the last normalization layer in each residual branch, so that each residual block initially behaves like an identity

90

groups (int): Number of groups for grouped convolution

91

width_per_group (int): Width per group for grouped convolution

92

replace_stride_with_dilation (list): One boolean per strided stage; True replaces that stage's stride-2 downsampling with a dilated convolution

93

norm_layer: Normalization layer

94

"""

95

96

def resnet18(weights=None, progress: bool = True, **kwargs) -> ResNet:

97

"""

98

ResNet-18 model.

99

100

Args:

101

weights: Pre-trained weights to use (None, 'DEFAULT', or specific weights)

102

progress (bool): Show download progress bar

103

**kwargs: Additional arguments passed to ResNet

104

105

Returns:

106

ResNet: ResNet-18 model

107

"""

108

109

def resnet34(weights=None, progress: bool = True, **kwargs) -> ResNet:

110

"""ResNet-34 model."""

111

112

def resnet50(weights=None, progress: bool = True, **kwargs) -> ResNet:

113

"""ResNet-50 model."""

114

115

def resnet101(weights=None, progress: bool = True, **kwargs) -> ResNet:

116

"""ResNet-101 model."""

117

118

def resnet152(weights=None, progress: bool = True, **kwargs) -> ResNet:

119

"""ResNet-152 model."""

120

121

def resnext50_32x4d(weights=None, progress: bool = True, **kwargs) -> ResNet:

122

"""ResNeXt-50 32x4d model with grouped convolutions."""

123

124

def resnext101_32x8d(weights=None, progress: bool = True, **kwargs) -> ResNet:

125

"""ResNeXt-101 32x8d model with grouped convolutions."""

126

127

def resnext101_64x4d(weights=None, progress: bool = True, **kwargs) -> ResNet:

128

"""ResNeXt-101 64x4d model with grouped convolutions."""

129

130

def wide_resnet50_2(weights=None, progress: bool = True, **kwargs) -> ResNet:

131

"""Wide ResNet-50-2 model with wider channels."""

132

133

def wide_resnet101_2(weights=None, progress: bool = True, **kwargs) -> ResNet:

134

"""Wide ResNet-101-2 model with wider channels."""

135

```

136

137

#### Vision Transformer

138

139

Transformer-based models for image classification using patch embeddings.

140

141

```python { .api }

142

class VisionTransformer(torch.nn.Module):

143

"""

144

Vision Transformer architecture.

145

146

Args:

147

image_size (int): Input image size

148

patch_size (int): Size of image patches

149

num_layers (int): Number of transformer layers

150

num_heads (int): Number of attention heads

151

hidden_dim (int): Hidden dimension size

152

mlp_dim (int): MLP dimension size

153

dropout (float): Dropout rate

154

attention_dropout (float): Attention dropout rate

155

num_classes (int): Number of classes

156

representation_size: Optional representation layer size

157

norm_layer: Normalization layer

158

conv_stem_configs: Optional convolutional stem configuration

159

"""

160

161

def vit_b_16(weights=None, progress: bool = True, **kwargs) -> VisionTransformer:

162

"""

163

ViT-Base/16 model with 16x16 patches.

164

165

Args:

166

weights: Pre-trained weights to use

167

progress (bool): Show download progress bar

168

**kwargs: Additional arguments

169

170

Returns:

171

VisionTransformer: ViT-Base/16 model

172

"""

173

174

def vit_b_32(weights=None, progress: bool = True, **kwargs) -> VisionTransformer:

175

"""ViT-Base/32 model with 32x32 patches."""

176

177

def vit_l_16(weights=None, progress: bool = True, **kwargs) -> VisionTransformer:

178

"""ViT-Large/16 model with 16x16 patches."""

179

180

def vit_l_32(weights=None, progress: bool = True, **kwargs) -> VisionTransformer:

181

"""ViT-Large/32 model with 32x32 patches."""

182

183

def vit_h_14(weights=None, progress: bool = True, **kwargs) -> VisionTransformer:

184

"""ViT-Huge/14 model with 14x14 patches."""

185

```

186

187

#### EfficientNet Family

188

189

Efficient convolutional networks optimized for accuracy and efficiency.

190

191

```python { .api }

192

class EfficientNet(torch.nn.Module):

193

"""

194

EfficientNet architecture with compound scaling.

195

196

Args:

197

inverted_residual_setting: Network structure configuration

198

dropout (float): Dropout rate

199

stochastic_depth_prob (float): Stochastic depth probability

200

num_classes (int): Number of classes

201

norm_layer: Normalization layer

202

last_channel: Optional last channel override

203

"""

204

205

def efficientnet_b0(weights=None, progress: bool = True, **kwargs) -> EfficientNet:

206

"""EfficientNet-B0 model."""

207

208

def efficientnet_b1(weights=None, progress: bool = True, **kwargs) -> EfficientNet:

209

"""EfficientNet-B1 model."""

210

211

def efficientnet_b2(weights=None, progress: bool = True, **kwargs) -> EfficientNet:

212

"""EfficientNet-B2 model."""

213

214

def efficientnet_b3(weights=None, progress: bool = True, **kwargs) -> EfficientNet:

215

"""EfficientNet-B3 model."""

216

217

def efficientnet_b4(weights=None, progress: bool = True, **kwargs) -> EfficientNet:

218

"""EfficientNet-B4 model."""

219

220

def efficientnet_b5(weights=None, progress: bool = True, **kwargs) -> EfficientNet:

221

"""EfficientNet-B5 model."""

222

223

def efficientnet_b6(weights=None, progress: bool = True, **kwargs) -> EfficientNet:

224

"""EfficientNet-B6 model."""

225

226

def efficientnet_b7(weights=None, progress: bool = True, **kwargs) -> EfficientNet:

227

"""EfficientNet-B7 model."""

228

229

def efficientnet_v2_s(weights=None, progress: bool = True, **kwargs) -> EfficientNet:

230

"""EfficientNetV2-Small model with improved training and scaling."""

231

232

def efficientnet_v2_m(weights=None, progress: bool = True, **kwargs) -> EfficientNet:

233

"""EfficientNetV2-Medium model."""

234

235

def efficientnet_v2_l(weights=None, progress: bool = True, **kwargs) -> EfficientNet:

236

"""EfficientNetV2-Large model."""

237

```

238

239

#### MobileNet Family

240

241

Lightweight models designed for mobile and embedded devices.

242

243

```python { .api }

244

class MobileNetV2(torch.nn.Module):

245

"""

246

MobileNetV2 architecture with inverted residuals and linear bottlenecks.

247

248

Args:

249

num_classes (int): Number of classes

250

width_mult (float): Width multiplier for channels

251

inverted_residual_setting: Optional network structure override

252

round_nearest (int): Round channels to nearest multiple

253

block: Block type for inverted residuals

254

norm_layer: Normalization layer

255

dropout (float): Dropout rate

256

"""

257

258

class MobileNetV3(torch.nn.Module):

259

"""

260

MobileNetV3 architecture with squeeze-and-excitation modules.

261

262

Args:

263

inverted_residual_setting: Network structure configuration

264

last_channel (int): Number of channels in final layer

265

num_classes (int): Number of classes

266

block: Block type for inverted residuals

267

norm_layer: Normalization layer

268

dropout (float): Dropout rate

269

"""

270

271

def mobilenet_v2(weights=None, progress: bool = True, **kwargs) -> MobileNetV2:

272

"""

273

MobileNetV2 model.

274

275

Args:

276

weights: Pre-trained weights to use

277

progress (bool): Show download progress bar

278

**kwargs: Additional arguments

279

280

Returns:

281

MobileNetV2: MobileNetV2 model

282

"""

283

284

def mobilenet_v3_large(weights=None, progress: bool = True, **kwargs) -> MobileNetV3:

285

"""MobileNetV3-Large model."""

286

287

def mobilenet_v3_small(weights=None, progress: bool = True, **kwargs) -> MobileNetV3:

288

"""MobileNetV3-Small model."""

289

```

290

291

#### Other Classification Models

292

293

Additional popular classification architectures.

294

295

```python { .api }

296

class AlexNet(torch.nn.Module):

297

"""AlexNet architecture for image classification."""

298

299

def alexnet(weights=None, progress: bool = True, **kwargs) -> AlexNet:

300

"""AlexNet model."""

301

302

class VGG(torch.nn.Module):

303

"""VGG architecture with customizable depth."""

304

305

def vgg11(weights=None, progress: bool = True, **kwargs) -> VGG:

306

"""VGG 11-layer model."""

307

308

def vgg11_bn(weights=None, progress: bool = True, **kwargs) -> VGG:

309

"""VGG 11-layer model with batch normalization."""

310

311

def vgg13(weights=None, progress: bool = True, **kwargs) -> VGG:

312

"""VGG 13-layer model."""

313

314

def vgg13_bn(weights=None, progress: bool = True, **kwargs) -> VGG:

315

"""VGG 13-layer model with batch normalization."""

316

317

def vgg16(weights=None, progress: bool = True, **kwargs) -> VGG:

318

"""VGG 16-layer model."""

319

320

def vgg16_bn(weights=None, progress: bool = True, **kwargs) -> VGG:

321

"""VGG 16-layer model with batch normalization."""

322

323

def vgg19(weights=None, progress: bool = True, **kwargs) -> VGG:

324

"""VGG 19-layer model."""

325

326

def vgg19_bn(weights=None, progress: bool = True, **kwargs) -> VGG:

327

"""VGG 19-layer model with batch normalization."""

328

329

class DenseNet(torch.nn.Module):

330

"""DenseNet architecture with dense connections."""

331

332

def densenet121(weights=None, progress: bool = True, **kwargs) -> DenseNet:

333

"""DenseNet-121 model."""

334

335

def densenet161(weights=None, progress: bool = True, **kwargs) -> DenseNet:

336

"""DenseNet-161 model."""

337

338

def densenet169(weights=None, progress: bool = True, **kwargs) -> DenseNet:

339

"""DenseNet-169 model."""

340

341

def densenet201(weights=None, progress: bool = True, **kwargs) -> DenseNet:

342

"""DenseNet-201 model."""

343

344

class Inception3(torch.nn.Module):

345

"""Inception v3 architecture."""

346

347

def inception_v3(weights=None, progress: bool = True, **kwargs) -> Inception3:

348

"""Inception v3 model."""

349

350

class GoogLeNet(torch.nn.Module):

351

"""GoogLeNet architecture with inception modules."""

352

353

def googlenet(weights=None, progress: bool = True, **kwargs) -> GoogLeNet:

354

"""GoogLeNet model."""

355

356

class ConvNeXt(torch.nn.Module):

357

"""ConvNeXt architecture with modernized ResNet design."""

358

359

def convnext_tiny(weights=None, progress: bool = True, **kwargs) -> ConvNeXt:

360

"""ConvNeXt Tiny model."""

361

362

def convnext_small(weights=None, progress: bool = True, **kwargs) -> ConvNeXt:

363

"""ConvNeXt Small model."""

364

365

def convnext_base(weights=None, progress: bool = True, **kwargs) -> ConvNeXt:

366

"""ConvNeXt Base model."""

367

368

def convnext_large(weights=None, progress: bool = True, **kwargs) -> ConvNeXt:

369

"""ConvNeXt Large model."""

370

371

class SwinTransformer(torch.nn.Module):

372

"""Swin Transformer with hierarchical feature maps."""

373

374

def swin_t(weights=None, progress: bool = True, **kwargs) -> SwinTransformer:

375

"""Swin Transformer Tiny model."""

376

377

def swin_s(weights=None, progress: bool = True, **kwargs) -> SwinTransformer:

378

"""Swin Transformer Small model."""

379

380

def swin_b(weights=None, progress: bool = True, **kwargs) -> SwinTransformer:

381

"""Swin Transformer Base model."""

382

383

class MaxVit(torch.nn.Module):

384

"""MaxVit architecture combining convolution and attention."""

385

386

def maxvit_t(weights=None, progress: bool = True, **kwargs) -> MaxVit:

387

"""MaxVit Tiny model."""

388

```

389

390

### Object Detection Models

391

392

#### Two-Stage Detectors

393

394

Region-based convolutional neural networks for object detection.

395

396

```python { .api }

397

class FasterRCNN(torch.nn.Module):

398

"""

399

Faster R-CNN model for object detection.

400

401

Args:

402

backbone: Feature extraction backbone

403

num_classes: Number of classes (including background)

404

min_size: Minimum image size for rescaling

405

max_size: Maximum image size for rescaling

406

image_mean: Mean for image normalization

407

image_std: Standard deviation for image normalization

408

rpn_anchor_generator: RPN anchor generator

409

rpn_head: RPN head

410

rpn_pre_nms_top_n_train: RPN pre-NMS top-k (training)

411

rpn_pre_nms_top_n_test: RPN pre-NMS top-k (testing)

412

rpn_post_nms_top_n_train: RPN post-NMS top-k (training)

413

rpn_post_nms_top_n_test: RPN post-NMS top-k (testing)

414

rpn_nms_thresh: RPN NMS threshold

415

rpn_fg_iou_thresh: RPN foreground IoU threshold

416

rpn_bg_iou_thresh: RPN background IoU threshold

417

rpn_batch_size_per_image: RPN batch size per image

418

rpn_positive_fraction: RPN positive fraction

419

box_roi_pool: RoI pooling layer for boxes

420

box_head: Box head

421

box_predictor: Box predictor

422

box_score_thresh: Box score threshold for inference

423

box_nms_thresh: Box NMS threshold

424

box_detections_per_img: Maximum detections per image

425

box_fg_iou_thresh: Box foreground IoU threshold

426

box_bg_iou_thresh: Box background IoU threshold

427

box_batch_size_per_image: Box batch size per image

428

box_positive_fraction: Box positive fraction

429

bbox_reg_weights: Bounding box regression weights

430

"""

431

432

def fasterrcnn_resnet50_fpn(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> FasterRCNN:

433

"""

434

Faster R-CNN model with ResNet-50-FPN backbone.

435

436

Args:

437

weights: Pre-trained weights to use

438

progress (bool): Show download progress bar

439

num_classes (int): Number of classes (overrides default)

440

weights_backbone: Backbone weights to use

441

trainable_backbone_layers (int): Number of trainable backbone layers

442

**kwargs: Additional arguments

443

444

Returns:

445

FasterRCNN: Faster R-CNN model

446

"""

447

448

def fasterrcnn_resnet50_fpn_v2(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> FasterRCNN:

449

"""Faster R-CNN model with ResNet-50-FPN v2 backbone."""

450

451

def fasterrcnn_mobilenet_v3_large_fpn(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> FasterRCNN:

452

"""Faster R-CNN model with MobileNetV3-Large-FPN backbone."""

453

454

def fasterrcnn_mobilenet_v3_large_320_fpn(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> FasterRCNN:

455

"""Faster R-CNN model with MobileNetV3-Large-320-FPN backbone."""

456

```

457

458

#### Instance Segmentation Models

459

460

Models for simultaneous object detection and instance segmentation.

461

462

```python { .api }

463

class MaskRCNN(torch.nn.Module):

464

"""

465

Mask R-CNN model for instance segmentation.

466

Extends Faster R-CNN with mask prediction branch.

467

468

Args:

469

backbone: Feature extraction backbone

470

num_classes: Number of classes (including background)

471

# ... (inherits all FasterRCNN parameters)

472

mask_roi_pool: RoI pooling layer for masks

473

mask_head: Mask head

474

mask_predictor: Mask predictor

475

"""

476

477

def maskrcnn_resnet50_fpn(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> MaskRCNN:

478

"""

479

Mask R-CNN model with ResNet-50-FPN backbone.

480

481

Args:

482

weights: Pre-trained weights to use

483

progress (bool): Show download progress bar

484

num_classes (int): Number of classes (overrides default)

485

weights_backbone: Backbone weights to use

486

trainable_backbone_layers (int): Number of trainable backbone layers

487

**kwargs: Additional arguments

488

489

Returns:

490

MaskRCNN: Mask R-CNN model

491

"""

492

493

def maskrcnn_resnet50_fpn_v2(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> MaskRCNN:

494

"""Mask R-CNN model with ResNet-50-FPN v2 backbone."""

495

```

496

497

#### Keypoint Detection Models

498

499

Models for human pose estimation and keypoint detection.

500

501

```python { .api }

502

class KeypointRCNN(torch.nn.Module):

503

"""

504

Keypoint R-CNN model for keypoint detection.

505

Extends Faster R-CNN with keypoint prediction branch.

506

507

Args:

508

backbone: Feature extraction backbone

509

num_classes: Number of classes (including background)

510

num_keypoints: Number of keypoints to detect

511

# ... (inherits all FasterRCNN parameters)

512

keypoint_roi_pool: RoI pooling layer for keypoints

513

keypoint_head: Keypoint head

514

keypoint_predictor: Keypoint predictor

515

"""

516

517

def keypointrcnn_resnet50_fpn(weights=None, progress: bool = True, num_classes=None, num_keypoints=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> KeypointRCNN:

518

"""

519

Keypoint R-CNN model with ResNet-50-FPN backbone.

520

521

Args:

522

weights: Pre-trained weights to use

523

progress (bool): Show download progress bar

524

num_classes (int): Number of classes (overrides default)

525

num_keypoints (int): Number of keypoints (overrides default)

526

weights_backbone: Backbone weights to use

527

trainable_backbone_layers (int): Number of trainable backbone layers

528

**kwargs: Additional arguments

529

530

Returns:

531

KeypointRCNN: Keypoint R-CNN model

532

"""

533

```

534

535

#### Single-Shot Detectors

536

537

One-stage object detection models for faster inference.

538

539

```python { .api }

540

class RetinaNet(torch.nn.Module):

541

"""

542

RetinaNet model with focal loss for object detection.

543

544

Args:

545

backbone: Feature extraction backbone

546

num_classes: Number of classes

547

min_size: Minimum image size for rescaling

548

max_size: Maximum image size for rescaling

549

image_mean: Mean for image normalization

550

image_std: Standard deviation for image normalization

551

anchor_generator: Anchor generator

552

head: Detection head

553

score_thresh: Score threshold for inference

554

nms_thresh: NMS threshold

555

detections_per_img: Maximum detections per image

556

fg_iou_thresh: Foreground IoU threshold

557

bg_iou_thresh: Background IoU threshold

558

topk_candidates: Top-k candidates to keep

559

"""

560

561

def retinanet_resnet50_fpn(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> RetinaNet:

562

"""RetinaNet model with ResNet-50-FPN backbone."""

563

564

def retinanet_resnet50_fpn_v2(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> RetinaNet:

565

"""RetinaNet model with ResNet-50-FPN v2 backbone."""

566

567

class SSD(torch.nn.Module):

568

"""Single Shot MultiBox Detector model."""

569

570

def ssd300_vgg16(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> SSD:

571

"""SSD300 model with VGG-16 backbone."""

572

573

def ssdlite320_mobilenet_v3_large(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> SSD:

574

"""SSDLite320 model with MobileNetV3-Large backbone."""

575

576

class FCOS(torch.nn.Module):

577

"""FCOS (Fully Convolutional One-Stage) object detector."""

578

579

def fcos_resnet50_fpn(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> FCOS:

580

"""FCOS model with ResNet-50-FPN backbone."""

581

```

582

583

### Semantic Segmentation Models

584

585

Pixel-level classification models for semantic segmentation.

586

587

```python { .api }

588

class FCN(torch.nn.Module):

589

"""

590

Fully Convolutional Network for semantic segmentation.

591

592

Args:

593

backbone: Feature extraction backbone

594

classifier: Classification head

595

aux_classifier: Auxiliary classification head

596

"""

597

598

def fcn_resnet50(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> FCN:

599

"""FCN model with ResNet-50 backbone."""

600

601

def fcn_resnet101(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> FCN:

602

"""FCN model with ResNet-101 backbone."""

603

604

class DeepLabV3(torch.nn.Module):

605

"""

606

DeepLabV3 model with atrous spatial pyramid pooling.

607

608

Args:

609

backbone: Feature extraction backbone

610

classifier: Classification head with ASPP

611

aux_classifier: Auxiliary classification head

612

"""

613

614

def deeplabv3_resnet50(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> DeepLabV3:

615

"""DeepLabV3 model with ResNet-50 backbone."""

616

617

def deeplabv3_resnet101(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> DeepLabV3:

618

"""DeepLabV3 model with ResNet-101 backbone."""

619

620

def deeplabv3_mobilenet_v3_large(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> DeepLabV3:

621

"""DeepLabV3 model with MobileNetV3-Large backbone."""

622

623

class LRASPP(torch.nn.Module):

624

"""

625

Lite R-ASPP model for fast semantic segmentation.

626

627

Args:

628

backbone: Feature extraction backbone

629

low_channels: Number of low-level feature channels

630

high_channels: Number of high-level feature channels

631

num_classes: Number of classes

632

inter_channels: Number of intermediate channels

633

"""

634

635

def lraspp_mobilenet_v3_large(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> LRASPP:

636

"""LRASPP model with MobileNetV3-Large backbone."""

637

```

638

639

### Video Models

640

641

Models for video understanding and temporal analysis.

642

643

```python { .api }

644

class VideoResNet(torch.nn.Module):

645

"""

646

3D ResNet architecture for video classification.

647

648

Args:

649

block: 3D block type

650

conv_makers: Convolution configuration for each layer

651

layers: Number of blocks per layer

652

stem: Stem configuration

653

num_classes: Number of classes

654

zero_init_residual: Zero-initialize residual connections

655

"""

656

657

def r3d_18(weights=None, progress: bool = True, **kwargs) -> VideoResNet:

658

"""3D ResNet-18 for video classification."""

659

660

def mc3_18(weights=None, progress: bool = True, **kwargs) -> VideoResNet:

661

"""Mixed Convolution 3D ResNet-18."""

662

663

def r2plus1d_18(weights=None, progress: bool = True, **kwargs) -> VideoResNet:

664

"""R(2+1)D ResNet-18 with factorized convolutions."""

665

666

class S3D(torch.nn.Module):

667

"""Separable 3D CNN architecture."""

668

669

def s3d(weights=None, progress: bool = True, **kwargs) -> S3D:

670

"""S3D model for video classification."""

671

672

class MViT(torch.nn.Module):

673

"""Multiscale Vision Transformer for video understanding."""

674

675

def mvit_v1_b(weights=None, progress: bool = True, **kwargs) -> MViT:

676

"""MViTv1-Base model."""

677

678

def mvit_v2_s(weights=None, progress: bool = True, **kwargs) -> MViT:

679

"""MViTv2-Small model."""

680

681

class SwinTransformer3D(torch.nn.Module):

682

"""3D Swin Transformer for video analysis."""

683

684

def swin3d_t(weights=None, progress: bool = True, **kwargs) -> SwinTransformer3D:

685

"""Swin3D Tiny model."""

686

687

def swin3d_s(weights=None, progress: bool = True, **kwargs) -> SwinTransformer3D:

688

"""Swin3D Small model."""

689

690

def swin3d_b(weights=None, progress: bool = True, **kwargs) -> SwinTransformer3D:

691

"""Swin3D Base model."""

692

```

693

694

### Optical Flow Models

695

696

Models for estimating optical flow between video frames.

697

698

```python { .api }

699

class RAFT(torch.nn.Module):

700

"""

701

RAFT (Recurrent All-Pairs Field Transforms) optical flow model.

702

703

Args:

704

feature_encoder: Feature extraction encoder

705

context_encoder: Context extraction encoder

706

correlation_block: Correlation block for feature matching

707

update_block: GRU-based update block

708

mask_predictor: Flow mask predictor

709

"""

710

711

def raft_large(weights=None, progress: bool = True, **kwargs) -> RAFT:

712

"""RAFT Large model for optical flow estimation."""

713

714

def raft_small(weights=None, progress: bool = True, **kwargs) -> RAFT:

715

"""RAFT Small model for optical flow estimation."""

716

```

717

718

### Quantized Models

719

720

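Quantized versions of popular models for efficient inference. These builders live in `torchvision.models.quantization` and share their names with the floating-point builders listed under Classification Models.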

721

722

```python { .api }

723

class QuantizableResNet(torch.nn.Module):

724

"""Quantizable ResNet architecture."""

725

726

# Quantized classification models

727

def resnet18(weights=None, progress: bool = True, quantize: bool = False, **kwargs):

728

"""Quantized ResNet-18 model."""

729

730

def resnet50(weights=None, progress: bool = True, quantize: bool = False, **kwargs):

731

"""Quantized ResNet-50 model."""

732

733

class QuantizableMobileNetV2(torch.nn.Module):

734

"""Quantizable MobileNetV2 architecture."""

735

736

def mobilenet_v2(weights=None, progress: bool = True, quantize: bool = False, **kwargs):

737

"""Quantized MobileNetV2 model."""

738

739

class QuantizableMobileNetV3(torch.nn.Module):

740

"""Quantizable MobileNetV3 architecture."""

741

742

def mobilenet_v3_large(weights=None, progress: bool = True, quantize: bool = False, **kwargs):

743

"""Quantized MobileNetV3-Large model."""

744

```

745

746

### Feature Extraction

747

748

Utilities for extracting intermediate features from pre-trained models.

749

750

```python { .api }

751

def create_feature_extractor(model: torch.nn.Module, return_nodes: dict, train_return_nodes=None, eval_return_nodes=None, tracer_kwargs=None, suppress_diff_warning: bool = False):

752

"""

753

Creates a feature extractor from any model.

754

755

Args:

756

model (torch.nn.Module): Model to extract features from

757

return_nodes (dict): Dict mapping node names to user-specified keys

758

train_return_nodes (dict, optional): Nodes to return during training

759

eval_return_nodes (dict, optional): Nodes to return during evaluation

760

tracer_kwargs (dict, optional): Keyword arguments for symbolic tracer

761

suppress_diff_warning (bool): Suppress difference warning

762

763

Returns:

764

FeatureExtractor: Model wrapper that returns intermediate features

765

"""

766

767

def get_graph_node_names(model: torch.nn.Module, tracer_kwargs=None, suppress_diff_warning: bool = False):

768

"""

769

Gets graph node names for feature extraction.

770

771

Args:

772

model (torch.nn.Module): Model to analyze

773

tracer_kwargs (dict, optional): Keyword arguments for symbolic tracer

774

suppress_diff_warning (bool): Suppress difference warning

775

776

Returns:

777

tuple: (train_nodes, eval_nodes) containing node names

778

"""

779

```

780

781

## Usage Examples

782

783

### Loading Pre-trained Models

784

785

```python

786

import torchvision.models as models

787

import torch

788

789

# Load a pre-trained ResNet-50

790

model = models.resnet50(weights='DEFAULT')

791

model.eval()

792

793

# Load model without weights

794

model = models.resnet50(weights=None)

795

796

# Load with specific weights

797

model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

798

799

# Modify for different number of classes

800

model = models.resnet50(weights='DEFAULT')

801

num_classes = 10

802

model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

803

```

804

805

### Object Detection

806

807

```python

808

import torchvision.models as models

809

import torchvision.transforms as transforms

810

from PIL import Image

811

812

# Load pre-trained Faster R-CNN

813

model = models.detection.fasterrcnn_resnet50_fpn(weights='DEFAULT')

814

model.eval()

815

816

# Prepare image

817

transform = transforms.Compose([transforms.ToTensor()])

818

image = Image.open('image.jpg')

819

image_tensor = transform(image)

820

821

# Inference (this snippet also needs `import torch` for torch.no_grad)

822

with torch.no_grad():

823

predictions = model([image_tensor])

824

825

# Access results

826

boxes = predictions[0]['boxes']

827

scores = predictions[0]['scores']

828

labels = predictions[0]['labels']

829

```

830

831

### Feature Extraction

832

833

```python

834

import torchvision.models as models

835

from torchvision.models.feature_extraction import create_feature_extractor

836

837

# Load pre-trained model

838

model = models.resnet50(weights='DEFAULT')

839

840

# Create feature extractor

841

return_nodes = {

842

'layer1.2.conv3': 'layer1',

843

'layer2.3.conv3': 'layer2',

844

'layer3.5.conv3': 'layer3',

845

'layer4.2.conv3': 'layer4'

846

}

847

848

feature_extractor = create_feature_extractor(model, return_nodes)

849

850

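# Extract features (input_tensor is a preprocessed image batch, e.g. torch.randn(1, 3, 224, 224); this snippet also needs `import torch`)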

851

with torch.no_grad():

852

features = feature_extractor(input_tensor)

853

854

# Access extracted features

855

layer1_features = features['layer1']

856

layer2_features = features['layer2']

857

```

858

859

### Video Classification

860

861

```python

862

import torchvision.models.video as video_models

863

import torch

864

865

# Load pre-trained video model

866

model = video_models.r3d_18(weights='DEFAULT')

867

model.eval()

868

869

# Prepare video tensor (batch_size, channels, frames, height, width)

870

video_tensor = torch.randn(1, 3, 16, 224, 224)

871

872

# Inference

873

with torch.no_grad():

874

predictions = model(video_tensor)

875

876

predicted_class = torch.argmax(predictions, dim=1)

877

```