0
# Models
1
2
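TorchVision provides neural network architectures, with optional pre-trained weights, for computer vision tasks including image classification, object detection, instance segmentation, keypoint detection, semantic segmentation, video classification, and optical flow estimation; quantization-ready variants of several classification models are also available. Every model is a standard `torch.nn.Module`, so it can be switched between training and evaluation mode with `.train()` and `.eval()`. Model builders take a `weights` argument: `None` for random initialization, `'DEFAULT'` for the best available pre-trained weights, or a specific weights enum member such as `ResNet50_Weights.IMAGENET1K_V2`.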
3
4
## Capabilities
5
6
### Model Management API
7
8
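High-level registry API for listing the available models, looking up a model builder or its weights by name, and instantiating a model from its name plus builder keyword arguments.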
9
10
```python { .api }
11
def get_model(name: str, **config) -> torch.nn.Module:
12
"""
13
Get model by name with configuration.
14
15
Args:
16
name (str): Model name
17
**config: Model-specific configuration parameters
18
19
Returns:
20
torch.nn.Module: Instantiated model
21
"""
22
23
def get_model_builder(name: str):
24
"""
25
Get model builder function by name.
26
27
Args:
28
name (str): Model name
29
30
Returns:
31
Callable: Model builder function
32
"""
33
34
def get_model_weights(name: str):
35
"""
36
Get available weights for a model.
37
38
Args:
39
name (str): Model name
40
41
Returns:
42
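Type[WeightsEnum]: The weights enum class for the model (e.g. `ResNet50_Weights`), whose members are the available pre-trained weights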
43
"""
44
45
def get_weight(name: str):
46
"""
47
Get specific weight by name.
48
49
Args:
50
name (str): Fully qualified weight name of the form '<WeightsEnum class>.<member>', e.g. 'ResNet50_Weights.IMAGENET1K_V1'
51
52
Returns:
53
WeightsEnum: The weights enum member matching the given name
54
"""
55
56
def list_models() -> list[str]:
57
"""
58
List all available models.
59
60
Returns:
61
list[str]: List of model names
62
"""
63
64
class Weights:
65
"""Dataclass for model weights metadata."""
66
url: str
67
transforms: callable
68
meta: dict
69
70
class WeightsEnum:
71
"""Enum base class for model weights."""
72
```
73
74
### Classification Models
75
76
#### ResNet Family
77
78
Deep residual networks with skip connections for image classification.
79
80
```python { .api }
81
class ResNet(torch.nn.Module):
82
"""
83
ResNet architecture implementation.
84
85
Args:
86
block: Block type (BasicBlock or Bottleneck)
87
layers (list): Number of blocks per layer
88
num_classes (int): Number of classes for classification
89
zero_init_residual (bool): Zero-initialize the last BatchNorm layer in each residual branch, so every residual block initially behaves like an identity mapping
90
groups (int): Number of groups for grouped convolution
91
width_per_group (int): Width per group for grouped convolution
92
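replace_stride_with_dilation (list): Three booleans, one each for layer2, layer3 and layer4, indicating whether that stage's stride-2 downsampling is replaced by a dilated convolution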
93
norm_layer: Normalization layer
94
"""
95
96
def resnet18(weights=None, progress: bool = True, **kwargs) -> ResNet:
97
"""
98
ResNet-18 model.
99
100
Args:
101
weights: Pre-trained weights to use (None, 'DEFAULT', or specific weights)
102
progress (bool): Show download progress bar
103
**kwargs: Additional arguments passed to ResNet
104
105
Returns:
106
ResNet: ResNet-18 model
107
"""
108
109
def resnet34(weights=None, progress: bool = True, **kwargs) -> ResNet:
110
"""ResNet-34 model."""
111
112
def resnet50(weights=None, progress: bool = True, **kwargs) -> ResNet:
113
"""ResNet-50 model."""
114
115
def resnet101(weights=None, progress: bool = True, **kwargs) -> ResNet:
116
"""ResNet-101 model."""
117
118
def resnet152(weights=None, progress: bool = True, **kwargs) -> ResNet:
119
"""ResNet-152 model."""
120
121
def resnext50_32x4d(weights=None, progress: bool = True, **kwargs) -> ResNet:
122
"""ResNeXt-50 32x4d model with grouped convolutions."""
123
124
def resnext101_32x8d(weights=None, progress: bool = True, **kwargs) -> ResNet:
125
"""ResNeXt-101 32x8d model with grouped convolutions."""
126
127
def resnext101_64x4d(weights=None, progress: bool = True, **kwargs) -> ResNet:
128
"""ResNeXt-101 64x4d model with grouped convolutions."""
129
130
def wide_resnet50_2(weights=None, progress: bool = True, **kwargs) -> ResNet:
131
"""Wide ResNet-50-2 model with wider channels."""
132
133
def wide_resnet101_2(weights=None, progress: bool = True, **kwargs) -> ResNet:
134
"""Wide ResNet-101-2 model with wider channels."""
135
```
136
137
#### Vision Transformer
138
139
Transformer-based models for image classification using patch embeddings.
140
141
```python { .api }
142
class VisionTransformer(torch.nn.Module):
143
"""
144
Vision Transformer architecture.
145
146
Args:
147
image_size (int): Input image size
148
patch_size (int): Size of image patches
149
num_layers (int): Number of transformer layers
150
num_heads (int): Number of attention heads
151
hidden_dim (int): Hidden dimension size
152
mlp_dim (int): MLP dimension size
153
dropout (float): Dropout rate
154
attention_dropout (float): Attention dropout rate
155
num_classes (int): Number of classes
156
representation_size: Optional representation layer size
157
norm_layer: Normalization layer
158
conv_stem_configs: Optional convolutional stem configuration
159
"""
160
161
def vit_b_16(weights=None, progress: bool = True, **kwargs) -> VisionTransformer:
162
"""
163
ViT-Base/16 model with 16x16 patches.
164
165
Args:
166
weights: Pre-trained weights to use
167
progress (bool): Show download progress bar
168
**kwargs: Additional arguments
169
170
Returns:
171
VisionTransformer: ViT-Base/16 model
172
"""
173
174
def vit_b_32(weights=None, progress: bool = True, **kwargs) -> VisionTransformer:
175
"""ViT-Base/32 model with 32x32 patches."""
176
177
def vit_l_16(weights=None, progress: bool = True, **kwargs) -> VisionTransformer:
178
"""ViT-Large/16 model with 16x16 patches."""
179
180
def vit_l_32(weights=None, progress: bool = True, **kwargs) -> VisionTransformer:
181
"""ViT-Large/32 model with 32x32 patches."""
182
183
def vit_h_14(weights=None, progress: bool = True, **kwargs) -> VisionTransformer:
184
"""ViT-Huge/14 model with 14x14 patches."""
185
```
186
187
#### EfficientNet Family
188
189
Convolutional networks that scale depth, width, and input resolution together (compound scaling) to balance accuracy against computational cost.
190
191
```python { .api }
192
class EfficientNet(torch.nn.Module):
193
"""
194
EfficientNet architecture with compound scaling.
195
196
Args:
197
inverted_residual_setting: Network structure configuration
198
dropout (float): Dropout rate
199
stochastic_depth_prob (float): Stochastic depth probability
200
num_classes (int): Number of classes
201
norm_layer: Normalization layer
202
last_channel: Optional last channel override
203
"""
204
205
def efficientnet_b0(weights=None, progress: bool = True, **kwargs) -> EfficientNet:
206
"""EfficientNet-B0 model."""
207
208
def efficientnet_b1(weights=None, progress: bool = True, **kwargs) -> EfficientNet:
209
"""EfficientNet-B1 model."""
210
211
def efficientnet_b2(weights=None, progress: bool = True, **kwargs) -> EfficientNet:
212
"""EfficientNet-B2 model."""
213
214
def efficientnet_b3(weights=None, progress: bool = True, **kwargs) -> EfficientNet:
215
"""EfficientNet-B3 model."""
216
217
def efficientnet_b4(weights=None, progress: bool = True, **kwargs) -> EfficientNet:
218
"""EfficientNet-B4 model."""
219
220
def efficientnet_b5(weights=None, progress: bool = True, **kwargs) -> EfficientNet:
221
"""EfficientNet-B5 model."""
222
223
def efficientnet_b6(weights=None, progress: bool = True, **kwargs) -> EfficientNet:
224
"""EfficientNet-B6 model."""
225
226
def efficientnet_b7(weights=None, progress: bool = True, **kwargs) -> EfficientNet:
227
"""EfficientNet-B7 model."""
228
229
def efficientnet_v2_s(weights=None, progress: bool = True, **kwargs) -> EfficientNet:
230
"""EfficientNetV2-Small model with improved training and scaling."""
231
232
def efficientnet_v2_m(weights=None, progress: bool = True, **kwargs) -> EfficientNet:
233
"""EfficientNetV2-Medium model."""
234
235
def efficientnet_v2_l(weights=None, progress: bool = True, **kwargs) -> EfficientNet:
236
"""EfficientNetV2-Large model."""
237
```
238
239
#### MobileNet Family
240
241
Lightweight models designed for mobile and embedded devices.
242
243
```python { .api }
244
class MobileNetV2(torch.nn.Module):
245
"""
246
MobileNetV2 architecture with inverted residuals and linear bottlenecks.
247
248
Args:
249
num_classes (int): Number of classes
250
width_mult (float): Width multiplier for channels
251
inverted_residual_setting: Optional network structure override
252
round_nearest (int): Round channels to nearest multiple
253
block: Block type for inverted residuals
254
norm_layer: Normalization layer
255
dropout (float): Dropout rate
256
"""
257
258
class MobileNetV3(torch.nn.Module):
259
"""
260
MobileNetV3 architecture with squeeze-and-excitation modules.
261
262
Args:
263
inverted_residual_setting: Network structure configuration
264
last_channel (int): Number of channels in final layer
265
num_classes (int): Number of classes
266
block: Block type for inverted residuals
267
norm_layer: Normalization layer
268
dropout (float): Dropout rate
269
"""
270
271
def mobilenet_v2(weights=None, progress: bool = True, **kwargs) -> MobileNetV2:
272
"""
273
MobileNetV2 model.
274
275
Args:
276
weights: Pre-trained weights to use
277
progress (bool): Show download progress bar
278
**kwargs: Additional arguments
279
280
Returns:
281
MobileNetV2: MobileNetV2 model
282
"""
283
284
def mobilenet_v3_large(weights=None, progress: bool = True, **kwargs) -> MobileNetV3:
285
"""MobileNetV3-Large model."""
286
287
def mobilenet_v3_small(weights=None, progress: bool = True, **kwargs) -> MobileNetV3:
288
"""MobileNetV3-Small model."""
289
```
290
291
#### Other Classification Models
292
293
Additional popular classification architectures.
294
295
```python { .api }
296
class AlexNet(torch.nn.Module):
297
"""AlexNet architecture for image classification."""
298
299
def alexnet(weights=None, progress: bool = True, **kwargs) -> AlexNet:
300
"""AlexNet model."""
301
302
class VGG(torch.nn.Module):
303
"""VGG architecture with customizable depth."""
304
305
def vgg11(weights=None, progress: bool = True, **kwargs) -> VGG:
306
"""VGG 11-layer model."""
307
308
def vgg11_bn(weights=None, progress: bool = True, **kwargs) -> VGG:
309
"""VGG 11-layer model with batch normalization."""
310
311
def vgg13(weights=None, progress: bool = True, **kwargs) -> VGG:
312
"""VGG 13-layer model."""
313
314
def vgg13_bn(weights=None, progress: bool = True, **kwargs) -> VGG:
315
"""VGG 13-layer model with batch normalization."""
316
317
def vgg16(weights=None, progress: bool = True, **kwargs) -> VGG:
318
"""VGG 16-layer model."""
319
320
def vgg16_bn(weights=None, progress: bool = True, **kwargs) -> VGG:
321
"""VGG 16-layer model with batch normalization."""
322
323
def vgg19(weights=None, progress: bool = True, **kwargs) -> VGG:
324
"""VGG 19-layer model."""
325
326
def vgg19_bn(weights=None, progress: bool = True, **kwargs) -> VGG:
327
"""VGG 19-layer model with batch normalization."""
328
329
class DenseNet(torch.nn.Module):
330
"""DenseNet architecture with dense connections."""
331
332
def densenet121(weights=None, progress: bool = True, **kwargs) -> DenseNet:
333
"""DenseNet-121 model."""
334
335
def densenet161(weights=None, progress: bool = True, **kwargs) -> DenseNet:
336
"""DenseNet-161 model."""
337
338
def densenet169(weights=None, progress: bool = True, **kwargs) -> DenseNet:
339
"""DenseNet-169 model."""
340
341
def densenet201(weights=None, progress: bool = True, **kwargs) -> DenseNet:
342
"""DenseNet-201 model."""
343
344
class Inception3(torch.nn.Module):
345
"""Inception v3 architecture."""
346
347
def inception_v3(weights=None, progress: bool = True, **kwargs) -> Inception3:
348
"""Inception v3 model."""
349
350
class GoogLeNet(torch.nn.Module):
351
"""GoogLeNet architecture with inception modules."""
352
353
def googlenet(weights=None, progress: bool = True, **kwargs) -> GoogLeNet:
354
"""GoogLeNet model."""
355
356
class ConvNeXt(torch.nn.Module):
357
"""ConvNeXt architecture with modernized ResNet design."""
358
359
def convnext_tiny(weights=None, progress: bool = True, **kwargs) -> ConvNeXt:
360
"""ConvNeXt Tiny model."""
361
362
def convnext_small(weights=None, progress: bool = True, **kwargs) -> ConvNeXt:
363
"""ConvNeXt Small model."""
364
365
def convnext_base(weights=None, progress: bool = True, **kwargs) -> ConvNeXt:
366
"""ConvNeXt Base model."""
367
368
def convnext_large(weights=None, progress: bool = True, **kwargs) -> ConvNeXt:
369
"""ConvNeXt Large model."""
370
371
class SwinTransformer(torch.nn.Module):
372
"""Swin Transformer with hierarchical feature maps."""
373
374
def swin_t(weights=None, progress: bool = True, **kwargs) -> SwinTransformer:
375
"""Swin Transformer Tiny model."""
376
377
def swin_s(weights=None, progress: bool = True, **kwargs) -> SwinTransformer:
378
"""Swin Transformer Small model."""
379
380
def swin_b(weights=None, progress: bool = True, **kwargs) -> SwinTransformer:
381
"""Swin Transformer Base model."""
382
383
class MaxVit(torch.nn.Module):
384
"""MaxVit architecture combining convolution and attention."""
385
386
def maxvit_t(weights=None, progress: bool = True, **kwargs) -> MaxVit:
387
"""MaxVit Tiny model."""
388
```
389
390
### Object Detection Models
391
392
#### Two-Stage Detectors
393
394
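Region-based convolutional neural networks (the R-CNN family) for object detection. The builders in this section, and in the instance segmentation, keypoint detection, and single-shot detector sections that follow, are available from `torchvision.models.detection`.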
395
396
```python { .api }
397
class FasterRCNN(torch.nn.Module):
398
"""
399
Faster R-CNN model for object detection.
400
401
Args:
402
backbone: Feature extraction backbone
403
num_classes: Number of classes (including background)
404
min_size: Minimum image size for rescaling
405
max_size: Maximum image size for rescaling
406
image_mean: Mean for image normalization
407
image_std: Standard deviation for image normalization
408
rpn_anchor_generator: RPN anchor generator
409
rpn_head: RPN head
410
rpn_pre_nms_top_n_train: RPN pre-NMS top-k (training)
411
rpn_pre_nms_top_n_test: RPN pre-NMS top-k (testing)
412
rpn_post_nms_top_n_train: RPN post-NMS top-k (training)
413
rpn_post_nms_top_n_test: RPN post-NMS top-k (testing)
414
rpn_nms_thresh: RPN NMS threshold
415
rpn_fg_iou_thresh: RPN foreground IoU threshold
416
rpn_bg_iou_thresh: RPN background IoU threshold
417
rpn_batch_size_per_image: RPN batch size per image
418
rpn_positive_fraction: RPN positive fraction
419
box_roi_pool: RoI pooling layer for boxes
420
box_head: Box head
421
box_predictor: Box predictor
422
box_score_thresh: Box score threshold for inference
423
box_nms_thresh: Box NMS threshold
424
box_detections_per_img: Maximum detections per image
425
box_fg_iou_thresh: Box foreground IoU threshold
426
box_bg_iou_thresh: Box background IoU threshold
427
box_batch_size_per_image: Box batch size per image
428
box_positive_fraction: Box positive fraction
429
bbox_reg_weights: Bounding box regression weights
430
"""
431
432
def fasterrcnn_resnet50_fpn(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> FasterRCNN:
433
"""
434
Faster R-CNN model with ResNet-50-FPN backbone.
435
436
Args:
437
weights: Pre-trained weights to use
438
progress (bool): Show download progress bar
439
num_classes (int): Number of classes (overrides default)
440
weights_backbone: Backbone weights to use
441
trainable_backbone_layers (int): Number of trainable backbone layers
442
**kwargs: Additional arguments
443
444
Returns:
445
FasterRCNN: Faster R-CNN model
446
"""
447
448
def fasterrcnn_resnet50_fpn_v2(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> FasterRCNN:
449
"""Faster R-CNN model with ResNet-50-FPN v2 backbone."""
450
451
def fasterrcnn_mobilenet_v3_large_fpn(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> FasterRCNN:
452
"""Faster R-CNN model with MobileNetV3-Large-FPN backbone."""
453
454
def fasterrcnn_mobilenet_v3_large_320_fpn(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> FasterRCNN:
455
"""Faster R-CNN model with MobileNetV3-Large-320-FPN backbone."""
456
```
457
458
#### Instance Segmentation Models
459
460
Models for simultaneous object detection and instance segmentation.
461
462
```python { .api }
463
class MaskRCNN(torch.nn.Module):
464
"""
465
Mask R-CNN model for instance segmentation.
466
Extends Faster R-CNN with mask prediction branch.
467
468
Args:
469
backbone: Feature extraction backbone
470
num_classes: Number of classes (including background)
471
# ... (inherits all FasterRCNN parameters)
472
mask_roi_pool: RoI pooling layer for masks
473
mask_head: Mask head
474
mask_predictor: Mask predictor
475
"""
476
477
def maskrcnn_resnet50_fpn(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> MaskRCNN:
478
"""
479
Mask R-CNN model with ResNet-50-FPN backbone.
480
481
Args:
482
weights: Pre-trained weights to use
483
progress (bool): Show download progress bar
484
num_classes (int): Number of classes (overrides default)
485
weights_backbone: Backbone weights to use
486
trainable_backbone_layers (int): Number of trainable backbone layers
487
**kwargs: Additional arguments
488
489
Returns:
490
MaskRCNN: Mask R-CNN model
491
"""
492
493
def maskrcnn_resnet50_fpn_v2(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> MaskRCNN:
494
"""Mask R-CNN model with ResNet-50-FPN v2 backbone."""
495
```
496
497
#### Keypoint Detection Models
498
499
Models for human pose estimation and keypoint detection.
500
501
```python { .api }
502
class KeypointRCNN(torch.nn.Module):
503
"""
504
Keypoint R-CNN model for keypoint detection.
505
Extends Faster R-CNN with keypoint prediction branch.
506
507
Args:
508
backbone: Feature extraction backbone
509
num_classes: Number of classes (including background)
510
num_keypoints: Number of keypoints to detect
511
# ... (inherits all FasterRCNN parameters)
512
keypoint_roi_pool: RoI pooling layer for keypoints
513
keypoint_head: Keypoint head
514
keypoint_predictor: Keypoint predictor
515
"""
516
517
def keypointrcnn_resnet50_fpn(weights=None, progress: bool = True, num_classes=None, num_keypoints=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> KeypointRCNN:
518
"""
519
Keypoint R-CNN model with ResNet-50-FPN backbone.
520
521
Args:
522
weights: Pre-trained weights to use
523
progress (bool): Show download progress bar
524
num_classes (int): Number of classes (overrides default)
525
num_keypoints (int): Number of keypoints (overrides default)
526
weights_backbone: Backbone weights to use
527
trainable_backbone_layers (int): Number of trainable backbone layers
528
**kwargs: Additional arguments
529
530
Returns:
531
KeypointRCNN: Keypoint R-CNN model
532
"""
533
```
534
535
#### Single-Shot Detectors
536
537
One-stage object detection models for faster inference.
538
539
```python { .api }
540
class RetinaNet(torch.nn.Module):
541
"""
542
RetinaNet model with focal loss for object detection.
543
544
Args:
545
backbone: Feature extraction backbone
546
num_classes: Number of classes
547
min_size: Minimum image size for rescaling
548
max_size: Maximum image size for rescaling
549
image_mean: Mean for image normalization
550
image_std: Standard deviation for image normalization
551
anchor_generator: Anchor generator
552
head: Detection head
553
score_thresh: Score threshold for inference
554
nms_thresh: NMS threshold
555
detections_per_img: Maximum detections per image
556
fg_iou_thresh: Foreground IoU threshold
557
bg_iou_thresh: Background IoU threshold
558
topk_candidates: Top-k candidates to keep
559
"""
560
561
def retinanet_resnet50_fpn(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> RetinaNet:
562
"""RetinaNet model with ResNet-50-FPN backbone."""
563
564
def retinanet_resnet50_fpn_v2(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> RetinaNet:
565
"""RetinaNet model with ResNet-50-FPN v2 backbone."""
566
567
class SSD(torch.nn.Module):
568
"""Single Shot MultiBox Detector model."""
569
570
def ssd300_vgg16(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> SSD:
571
"""SSD300 model with VGG-16 backbone."""
572
573
def ssdlite320_mobilenet_v3_large(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> SSD:
574
"""SSDLite320 model with MobileNetV3-Large backbone."""
575
576
class FCOS(torch.nn.Module):
577
"""FCOS (Fully Convolutional One-Stage) object detector."""
578
579
def fcos_resnet50_fpn(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> FCOS:
580
"""FCOS model with ResNet-50-FPN backbone."""
581
```
582
583
### Semantic Segmentation Models
584
585
Pixel-level classification models for semantic segmentation, available from `torchvision.models.segmentation`.
586
587
```python { .api }
588
class FCN(torch.nn.Module):
589
"""
590
Fully Convolutional Network for semantic segmentation.
591
592
Args:
593
backbone: Feature extraction backbone
594
classifier: Classification head
595
aux_classifier: Auxiliary classification head
596
"""
597
598
def fcn_resnet50(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> FCN:
599
"""FCN model with ResNet-50 backbone."""
600
601
def fcn_resnet101(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> FCN:
602
"""FCN model with ResNet-101 backbone."""
603
604
class DeepLabV3(torch.nn.Module):
605
"""
606
DeepLabV3 model with atrous spatial pyramid pooling.
607
608
Args:
609
backbone: Feature extraction backbone
610
classifier: Classification head with ASPP
611
aux_classifier: Auxiliary classification head
612
"""
613
614
def deeplabv3_resnet50(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> DeepLabV3:
615
"""DeepLabV3 model with ResNet-50 backbone."""
616
617
def deeplabv3_resnet101(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> DeepLabV3:
618
"""DeepLabV3 model with ResNet-101 backbone."""
619
620
def deeplabv3_mobilenet_v3_large(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> DeepLabV3:
621
"""DeepLabV3 model with MobileNetV3-Large backbone."""
622
623
class LRASPP(torch.nn.Module):
624
"""
625
Lite R-ASPP model for fast semantic segmentation.
626
627
Args:
628
backbone: Feature extraction backbone
629
low_channels: Number of low-level feature channels
630
high_channels: Number of high-level feature channels
631
num_classes: Number of classes
632
inter_channels: Number of intermediate channels
633
"""
634
635
def lraspp_mobilenet_v3_large(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> LRASPP:
636
"""LRASPP model with MobileNetV3-Large backbone."""
637
```
638
639
### Video Models
640
641
Models for video classification, available from `torchvision.models.video`. They take clip tensors shaped (batch_size, channels, frames, height, width).
642
643
```python { .api }
644
class VideoResNet(torch.nn.Module):
645
"""
646
3D ResNet architecture for video classification.
647
648
Args:
649
block: 3D block type
650
conv_makers: Convolution configuration for each layer
651
layers: Number of blocks per layer
652
stem: Stem configuration
653
num_classes: Number of classes
654
zero_init_residual: Zero-initialize residual connections
655
"""
656
657
def r3d_18(weights=None, progress: bool = True, **kwargs) -> VideoResNet:
658
"""3D ResNet-18 for video classification."""
659
660
def mc3_18(weights=None, progress: bool = True, **kwargs) -> VideoResNet:
661
"""Mixed Convolution 3D ResNet-18."""
662
663
def r2plus1d_18(weights=None, progress: bool = True, **kwargs) -> VideoResNet:
664
"""R(2+1)D ResNet-18 with factorized convolutions."""
665
666
class S3D(torch.nn.Module):
667
"""Separable 3D CNN architecture."""
668
669
def s3d(weights=None, progress: bool = True, **kwargs) -> S3D:
670
"""S3D model for video classification."""
671
672
class MViT(torch.nn.Module):
673
"""Multiscale Vision Transformer for video understanding."""
674
675
def mvit_v1_b(weights=None, progress: bool = True, **kwargs) -> MViT:
676
"""MViTv1-Base model."""
677
678
def mvit_v2_s(weights=None, progress: bool = True, **kwargs) -> MViT:
679
"""MViTv2-Small model."""
680
681
class SwinTransformer3D(torch.nn.Module):
682
"""3D Swin Transformer for video analysis."""
683
684
def swin3d_t(weights=None, progress: bool = True, **kwargs) -> SwinTransformer3D:
685
"""Swin3D Tiny model."""
686
687
def swin3d_s(weights=None, progress: bool = True, **kwargs) -> SwinTransformer3D:
688
"""Swin3D Small model."""
689
690
def swin3d_b(weights=None, progress: bool = True, **kwargs) -> SwinTransformer3D:
691
"""Swin3D Base model."""
692
```
693
694
### Optical Flow Models
695
696
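Models for estimating dense optical flow between a pair of images (for example, consecutive video frames), available from `torchvision.models.optical_flow`.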
697
698
```python { .api }
699
class RAFT(torch.nn.Module):
700
"""
701
RAFT (Recurrent All-Pairs Field Transforms) optical flow model.
702
703
Args:
704
feature_encoder: Feature extraction encoder
705
context_encoder: Context extraction encoder
706
correlation_block: Correlation block for feature matching
707
update_block: GRU-based update block
708
mask_predictor: Flow mask predictor
709
"""
710
711
def raft_large(weights=None, progress: bool = True, **kwargs) -> RAFT:
712
"""RAFT Large model for optical flow estimation."""
713
714
def raft_small(weights=None, progress: bool = True, **kwargs) -> RAFT:
715
"""RAFT Small model for optical flow estimation."""
716
```
717
718
### Quantized Models
719
720
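Quantization-ready variants of popular classification models for efficient inference. These builders live in `torchvision.models.quantization` and share their names with the floating-point builders above. Pass `quantize=True` to get a quantized model; with the default `quantize=False` you get the quantizable floating-point model.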
721
722
```python { .api }
723
class QuantizableResNet(torch.nn.Module):
724
"""Quantizable ResNet architecture."""
725
726
# Quantized classification models
727
def resnet18(weights=None, progress: bool = True, quantize: bool = False, **kwargs):
728
"""Quantized ResNet-18 model."""
729
730
def resnet50(weights=None, progress: bool = True, quantize: bool = False, **kwargs):
731
"""Quantized ResNet-50 model."""
732
733
class QuantizableMobileNetV2(torch.nn.Module):
734
"""Quantizable MobileNetV2 architecture."""
735
736
def mobilenet_v2(weights=None, progress: bool = True, quantize: bool = False, **kwargs):
737
"""Quantized MobileNetV2 model."""
738
739
class QuantizableMobileNetV3(torch.nn.Module):
740
"""Quantizable MobileNetV3 architecture."""
741
742
def mobilenet_v3_large(weights=None, progress: bool = True, quantize: bool = False, **kwargs):
743
"""Quantized MobileNetV3-Large model."""
744
```
745
746
### Feature Extraction
747
748
Utilities, available from `torchvision.models.feature_extraction`, for extracting intermediate outputs from a model by symbolically tracing it with `torch.fx`.
749
750
```python { .api }
751
def create_feature_extractor(model: torch.nn.Module, return_nodes: dict, train_return_nodes=None, eval_return_nodes=None, tracer_kwargs=None, suppress_diff_warning: bool = False):
752
"""
753
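Creates a new module that returns the outputs of selected intermediate nodes of `model` as a dictionary. The model must be symbolically traceable with `torch.fx`.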
754
755
Args:
756
model (torch.nn.Module): Model to extract features from
757
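return_nodes (list or dict): Nodes whose outputs are returned. A dict maps node names to the keys used in the output dictionary; a list uses the node names themselves as keys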
758
train_return_nodes (dict, optional): Nodes to return during training
759
eval_return_nodes (dict, optional): Nodes to return during evaluation
760
tracer_kwargs (dict, optional): Keyword arguments for symbolic tracer
761
suppress_diff_warning (bool): Suppress the warning emitted when the train-mode and eval-mode graphs differ
762
763
Returns:
764
torch.fx.GraphModule: Module whose forward pass returns a dict of the requested intermediate outputs
765
"""
766
767
def get_graph_node_names(model: torch.nn.Module, tracer_kwargs=None, suppress_diff_warning: bool = False):
768
"""
769
Gets graph node names for feature extraction.
770
771
Args:
772
model (torch.nn.Module): Model to analyze
773
tracer_kwargs (dict, optional): Keyword arguments for symbolic tracer
774
suppress_diff_warning (bool): Suppress the warning emitted when the train-mode and eval-mode graphs differ
775
776
Returns:
777
tuple: (train_nodes, eval_nodes) containing node names
778
"""
779
```
780
781
## Usage Examples
782
783
### Loading Pre-trained Models
784
785
```python
786
import torchvision.models as models
787
import torch
788
789
# Load a pre-trained ResNet-50
790
model = models.resnet50(weights='DEFAULT')
791
model.eval()
792
793
# Load model without weights
794
model = models.resnet50(weights=None)
795
796
# Load with specific weights
797
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
798
799
# Modify for different number of classes
800
model = models.resnet50(weights='DEFAULT')
801
num_classes = 10
802
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
803
```
804
805
### Object Detection
806
807
```python
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

# Load pre-trained Faster R-CNN
model = models.detection.fasterrcnn_resnet50_fpn(weights='DEFAULT')
model.eval()

# Prepare image (convert to RGB so grayscale/RGBA files still yield 3 channels)
transform = transforms.Compose([transforms.ToTensor()])
image = Image.open('image.jpg').convert('RGB')
image_tensor = transform(image)

# Inference: the model takes a list of (C, H, W) image tensors
with torch.no_grad():
    predictions = model([image_tensor])

# Access results (one dict per input image)
boxes = predictions[0]['boxes']
scores = predictions[0]['scores']
labels = predictions[0]['labels']
```
830
831
### Feature Extraction
832
833
```python
import torch
import torchvision.models as models
from torchvision.models.feature_extraction import create_feature_extractor

# Load pre-trained model and switch to evaluation mode
model = models.resnet50(weights='DEFAULT')
model.eval()

# Create feature extractor
return_nodes = {
    'layer1.2.conv3': 'layer1',
    'layer2.3.conv3': 'layer2',
    'layer3.5.conv3': 'layer3',
    'layer4.2.conv3': 'layer4'
}

feature_extractor = create_feature_extractor(model, return_nodes)

# Extract features from a dummy batch of one 224x224 RGB image
input_tensor = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    features = feature_extractor(input_tensor)

# Access extracted features
layer1_features = features['layer1']
layer2_features = features['layer2']
```
858
859
### Video Classification
860
861
```python
import torchvision.models.video as video_models
import torch

# Load pre-trained video model
model = video_models.r3d_18(weights='DEFAULT')
model.eval()

# Prepare video tensor (batch_size, channels, frames, height, width);
# the r3d_18 pre-trained weights expect 112x112 center crops
video_tensor = torch.randn(1, 3, 16, 112, 112)

# Inference
with torch.no_grad():
    predictions = model(video_tensor)

predicted_class = torch.argmax(predictions, dim=1)
```