Computer vision library for PyTorch with datasets, model architectures, and image/video transforms.
TorchVision provides pre-trained neural network models for various computer vision tasks including image classification, object detection, instance segmentation, semantic segmentation, keypoint detection, and video analysis. All models support both training and evaluation modes with optional pre-trained weights.
High-level API for discovering and loading models with configuration.
def get_model(name: str, **config) -> torch.nn.Module:
"""
Get model by name with configuration.
Args:
name (str): Model name
**config: Model-specific configuration parameters
Returns:
torch.nn.Module: Instantiated model
"""
def get_model_builder(name: str):
"""
Get model builder function by name.
Args:
name (str): Model name
Returns:
Callable: Model builder function
"""
def get_model_weights(name: str):
"""
Get available weights for a model.
Args:
name (str): Model name
Returns:
WeightsEnum: The weights enum class associated with the model
"""
def get_weight(name: str):
"""
Get specific weight by name.
Args:
name (str): Weight name
Returns:
Weight object
"""
def list_models() -> list[str]:
"""
List all available models.
Returns:
list[str]: List of model names
"""
class Weights:
"""Dataclass for model weights metadata."""
url: str
transforms: callable
meta: dict
class WeightsEnum:
"""Enum base class for model weights."""Deep residual networks with skip connections for image classification.
class ResNet(torch.nn.Module):
"""
ResNet architecture implementation.
Args:
block: Block type (BasicBlock or Bottleneck)
layers (list): Number of blocks per layer
num_classes (int): Number of classes for classification
zero_init_residual (bool): Zero-initialize residual connections
groups (int): Number of groups for grouped convolution
width_per_group (int): Width per group for grouped convolution
replace_stride_with_dilation (list): Replace stride with dilation
norm_layer: Normalization layer
"""
def resnet18(weights=None, progress: bool = True, **kwargs) -> ResNet:
"""
ResNet-18 model.
Args:
weights: Pre-trained weights to use (None, 'DEFAULT', or specific weights)
progress (bool): Show download progress bar
**kwargs: Additional arguments passed to ResNet
Returns:
ResNet: ResNet-18 model
"""
def resnet34(weights=None, progress: bool = True, **kwargs) -> ResNet:
"""ResNet-34 model."""
def resnet50(weights=None, progress: bool = True, **kwargs) -> ResNet:
"""ResNet-50 model."""
def resnet101(weights=None, progress: bool = True, **kwargs) -> ResNet:
"""ResNet-101 model."""
def resnet152(weights=None, progress: bool = True, **kwargs) -> ResNet:
"""ResNet-152 model."""
def resnext50_32x4d(weights=None, progress: bool = True, **kwargs) -> ResNet:
"""ResNeXt-50 32x4d model with grouped convolutions."""
def resnext101_32x8d(weights=None, progress: bool = True, **kwargs) -> ResNet:
"""ResNeXt-101 32x8d model with grouped convolutions."""
def resnext101_64x4d(weights=None, progress: bool = True, **kwargs) -> ResNet:
"""ResNeXt-101 64x4d model with grouped convolutions."""
def wide_resnet50_2(weights=None, progress: bool = True, **kwargs) -> ResNet:
"""Wide ResNet-50-2 model with wider channels."""
def wide_resnet101_2(weights=None, progress: bool = True, **kwargs) -> ResNet:
"""Wide ResNet-101-2 model with wider channels."""Transformer-based models for image classification using patch embeddings.
class VisionTransformer(torch.nn.Module):
"""
Vision Transformer architecture.
Args:
image_size (int): Input image size
patch_size (int): Size of image patches
num_layers (int): Number of transformer layers
num_heads (int): Number of attention heads
hidden_dim (int): Hidden dimension size
mlp_dim (int): MLP dimension size
dropout (float): Dropout rate
attention_dropout (float): Attention dropout rate
num_classes (int): Number of classes
representation_size: Optional representation layer size
norm_layer: Normalization layer
conv_stem_configs: Optional convolutional stem configuration
"""
def vit_b_16(weights=None, progress: bool = True, **kwargs) -> VisionTransformer:
"""
ViT-Base/16 model with 16x16 patches.
Args:
weights: Pre-trained weights to use
progress (bool): Show download progress bar
**kwargs: Additional arguments
Returns:
VisionTransformer: ViT-Base/16 model
"""
def vit_b_32(weights=None, progress: bool = True, **kwargs) -> VisionTransformer:
"""ViT-Base/32 model with 32x32 patches."""
def vit_l_16(weights=None, progress: bool = True, **kwargs) -> VisionTransformer:
"""ViT-Large/16 model with 16x16 patches."""
def vit_l_32(weights=None, progress: bool = True, **kwargs) -> VisionTransformer:
"""ViT-Large/32 model with 32x32 patches."""
def vit_h_14(weights=None, progress: bool = True, **kwargs) -> VisionTransformer:
"""ViT-Huge/14 model with 14x14 patches."""Efficient convolutional networks optimized for accuracy and efficiency.
class EfficientNet(torch.nn.Module):
"""
EfficientNet architecture with compound scaling.
Args:
inverted_residual_setting: Network structure configuration
dropout (float): Dropout rate
stochastic_depth_prob (float): Stochastic depth probability
num_classes (int): Number of classes
norm_layer: Normalization layer
last_channel: Optional last channel override
"""
def efficientnet_b0(weights=None, progress: bool = True, **kwargs) -> EfficientNet:
"""EfficientNet-B0 model."""
def efficientnet_b1(weights=None, progress: bool = True, **kwargs) -> EfficientNet:
"""EfficientNet-B1 model."""
def efficientnet_b2(weights=None, progress: bool = True, **kwargs) -> EfficientNet:
"""EfficientNet-B2 model."""
def efficientnet_b3(weights=None, progress: bool = True, **kwargs) -> EfficientNet:
"""EfficientNet-B3 model."""
def efficientnet_b4(weights=None, progress: bool = True, **kwargs) -> EfficientNet:
"""EfficientNet-B4 model."""
def efficientnet_b5(weights=None, progress: bool = True, **kwargs) -> EfficientNet:
"""EfficientNet-B5 model."""
def efficientnet_b6(weights=None, progress: bool = True, **kwargs) -> EfficientNet:
"""EfficientNet-B6 model."""
def efficientnet_b7(weights=None, progress: bool = True, **kwargs) -> EfficientNet:
"""EfficientNet-B7 model."""
def efficientnet_v2_s(weights=None, progress: bool = True, **kwargs) -> EfficientNet:
"""EfficientNetV2-Small model with improved training and scaling."""
def efficientnet_v2_m(weights=None, progress: bool = True, **kwargs) -> EfficientNet:
"""EfficientNetV2-Medium model."""
def efficientnet_v2_l(weights=None, progress: bool = True, **kwargs) -> EfficientNet:
"""EfficientNetV2-Large model."""Lightweight models designed for mobile and embedded devices.
class MobileNetV2(torch.nn.Module):
"""
MobileNetV2 architecture with inverted residuals and linear bottlenecks.
Args:
num_classes (int): Number of classes
width_mult (float): Width multiplier for channels
inverted_residual_setting: Optional network structure override
round_nearest (int): Round channels to nearest multiple
block: Block type for inverted residuals
norm_layer: Normalization layer
dropout (float): Dropout rate
"""
class MobileNetV3(torch.nn.Module):
"""
MobileNetV3 architecture with squeeze-and-excitation modules.
Args:
inverted_residual_setting: Network structure configuration
last_channel (int): Number of channels in final layer
num_classes (int): Number of classes
block: Block type for inverted residuals
norm_layer: Normalization layer
dropout (float): Dropout rate
"""
def mobilenet_v2(weights=None, progress: bool = True, **kwargs) -> MobileNetV2:
"""
MobileNetV2 model.
Args:
weights: Pre-trained weights to use
progress (bool): Show download progress bar
**kwargs: Additional arguments
Returns:
MobileNetV2: MobileNetV2 model
"""
def mobilenet_v3_large(weights=None, progress: bool = True, **kwargs) -> MobileNetV3:
"""MobileNetV3-Large model."""
def mobilenet_v3_small(weights=None, progress: bool = True, **kwargs) -> MobileNetV3:
"""MobileNetV3-Small model."""Additional popular classification architectures.
class AlexNet(torch.nn.Module):
"""AlexNet architecture for image classification."""
def alexnet(weights=None, progress: bool = True, **kwargs) -> AlexNet:
"""AlexNet model."""
class VGG(torch.nn.Module):
"""VGG architecture with customizable depth."""
def vgg11(weights=None, progress: bool = True, **kwargs) -> VGG:
"""VGG 11-layer model."""
def vgg11_bn(weights=None, progress: bool = True, **kwargs) -> VGG:
"""VGG 11-layer model with batch normalization."""
def vgg13(weights=None, progress: bool = True, **kwargs) -> VGG:
"""VGG 13-layer model."""
def vgg13_bn(weights=None, progress: bool = True, **kwargs) -> VGG:
"""VGG 13-layer model with batch normalization."""
def vgg16(weights=None, progress: bool = True, **kwargs) -> VGG:
"""VGG 16-layer model."""
def vgg16_bn(weights=None, progress: bool = True, **kwargs) -> VGG:
"""VGG 16-layer model with batch normalization."""
def vgg19(weights=None, progress: bool = True, **kwargs) -> VGG:
"""VGG 19-layer model."""
def vgg19_bn(weights=None, progress: bool = True, **kwargs) -> VGG:
"""VGG 19-layer model with batch normalization."""
class DenseNet(torch.nn.Module):
"""DenseNet architecture with dense connections."""
def densenet121(weights=None, progress: bool = True, **kwargs) -> DenseNet:
"""DenseNet-121 model."""
def densenet161(weights=None, progress: bool = True, **kwargs) -> DenseNet:
"""DenseNet-161 model."""
def densenet169(weights=None, progress: bool = True, **kwargs) -> DenseNet:
"""DenseNet-169 model."""
def densenet201(weights=None, progress: bool = True, **kwargs) -> DenseNet:
"""DenseNet-201 model."""
class Inception3(torch.nn.Module):
"""Inception v3 architecture."""
def inception_v3(weights=None, progress: bool = True, **kwargs) -> Inception3:
"""Inception v3 model."""
class GoogLeNet(torch.nn.Module):
"""GoogLeNet architecture with inception modules."""
def googlenet(weights=None, progress: bool = True, **kwargs) -> GoogLeNet:
"""GoogLeNet model."""
class ConvNeXt(torch.nn.Module):
"""ConvNeXt architecture with modernized ResNet design."""
def convnext_tiny(weights=None, progress: bool = True, **kwargs) -> ConvNeXt:
"""ConvNeXt Tiny model."""
def convnext_small(weights=None, progress: bool = True, **kwargs) -> ConvNeXt:
"""ConvNeXt Small model."""
def convnext_base(weights=None, progress: bool = True, **kwargs) -> ConvNeXt:
"""ConvNeXt Base model."""
def convnext_large(weights=None, progress: bool = True, **kwargs) -> ConvNeXt:
"""ConvNeXt Large model."""
class SwinTransformer(torch.nn.Module):
"""Swin Transformer with hierarchical feature maps."""
def swin_t(weights=None, progress: bool = True, **kwargs) -> SwinTransformer:
"""Swin Transformer Tiny model."""
def swin_s(weights=None, progress: bool = True, **kwargs) -> SwinTransformer:
"""Swin Transformer Small model."""
def swin_b(weights=None, progress: bool = True, **kwargs) -> SwinTransformer:
"""Swin Transformer Base model."""
class MaxVit(torch.nn.Module):
"""MaxVit architecture combining convolution and attention."""
def maxvit_t(weights=None, progress: bool = True, **kwargs) -> MaxVit:
"""MaxVit Tiny model."""Region-based convolutional neural networks for object detection.
class FasterRCNN(torch.nn.Module):
"""
Faster R-CNN model for object detection.
Args:
backbone: Feature extraction backbone
num_classes: Number of classes (including background)
min_size: Minimum image size for rescaling
max_size: Maximum image size for rescaling
image_mean: Mean for image normalization
image_std: Standard deviation for image normalization
rpn_anchor_generator: RPN anchor generator
rpn_head: RPN head
rpn_pre_nms_top_n_train: RPN pre-NMS top-k (training)
rpn_pre_nms_top_n_test: RPN pre-NMS top-k (testing)
rpn_post_nms_top_n_train: RPN post-NMS top-k (training)
rpn_post_nms_top_n_test: RPN post-NMS top-k (testing)
rpn_nms_thresh: RPN NMS threshold
rpn_fg_iou_thresh: RPN foreground IoU threshold
rpn_bg_iou_thresh: RPN background IoU threshold
rpn_batch_size_per_image: RPN batch size per image
rpn_positive_fraction: RPN positive fraction
box_roi_pool: RoI pooling layer for boxes
box_head: Box head
box_predictor: Box predictor
box_score_thresh: Box score threshold for inference
box_nms_thresh: Box NMS threshold
box_detections_per_img: Maximum detections per image
box_fg_iou_thresh: Box foreground IoU threshold
box_bg_iou_thresh: Box background IoU threshold
box_batch_size_per_image: Box batch size per image
box_positive_fraction: Box positive fraction
bbox_reg_weights: Bounding box regression weights
"""
def fasterrcnn_resnet50_fpn(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> FasterRCNN:
"""
Faster R-CNN model with ResNet-50-FPN backbone.
Args:
weights: Pre-trained weights to use
progress (bool): Show download progress bar
num_classes (int): Number of classes (overrides default)
weights_backbone: Backbone weights to use
trainable_backbone_layers (int): Number of trainable backbone layers
**kwargs: Additional arguments
Returns:
FasterRCNN: Faster R-CNN model
"""
def fasterrcnn_resnet50_fpn_v2(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> FasterRCNN:
"""Faster R-CNN model with ResNet-50-FPN v2 backbone."""
def fasterrcnn_mobilenet_v3_large_fpn(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> FasterRCNN:
"""Faster R-CNN model with MobileNetV3-Large-FPN backbone."""
def fasterrcnn_mobilenet_v3_large_320_fpn(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> FasterRCNN:
"""Faster R-CNN model with MobileNetV3-Large-320-FPN backbone."""Models for simultaneous object detection and instance segmentation.
class MaskRCNN(torch.nn.Module):
"""
Mask R-CNN model for instance segmentation.
Extends Faster R-CNN with mask prediction branch.
Args:
backbone: Feature extraction backbone
num_classes: Number of classes (including background)
# ... (inherits all FasterRCNN parameters)
mask_roi_pool: RoI pooling layer for masks
mask_head: Mask head
mask_predictor: Mask predictor
"""
def maskrcnn_resnet50_fpn(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> MaskRCNN:
"""
Mask R-CNN model with ResNet-50-FPN backbone.
Args:
weights: Pre-trained weights to use
progress (bool): Show download progress bar
num_classes (int): Number of classes (overrides default)
weights_backbone: Backbone weights to use
trainable_backbone_layers (int): Number of trainable backbone layers
**kwargs: Additional arguments
Returns:
MaskRCNN: Mask R-CNN model
"""
def maskrcnn_resnet50_fpn_v2(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> MaskRCNN:
"""Mask R-CNN model with ResNet-50-FPN v2 backbone."""Models for human pose estimation and keypoint detection.
class KeypointRCNN(torch.nn.Module):
"""
Keypoint R-CNN model for keypoint detection.
Extends Faster R-CNN with keypoint prediction branch.
Args:
backbone: Feature extraction backbone
num_classes: Number of classes (including background)
num_keypoints: Number of keypoints to detect
# ... (inherits all FasterRCNN parameters)
keypoint_roi_pool: RoI pooling layer for keypoints
keypoint_head: Keypoint head
keypoint_predictor: Keypoint predictor
"""
def keypointrcnn_resnet50_fpn(weights=None, progress: bool = True, num_classes=None, num_keypoints=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> KeypointRCNN:
"""
Keypoint R-CNN model with ResNet-50-FPN backbone.
Args:
weights: Pre-trained weights to use
progress (bool): Show download progress bar
num_classes (int): Number of classes (overrides default)
num_keypoints (int): Number of keypoints (overrides default)
weights_backbone: Backbone weights to use
trainable_backbone_layers (int): Number of trainable backbone layers
**kwargs: Additional arguments
Returns:
KeypointRCNN: Keypoint R-CNN model
"""One-stage object detection models for faster inference.
class RetinaNet(torch.nn.Module):
"""
RetinaNet model with focal loss for object detection.
Args:
backbone: Feature extraction backbone
num_classes: Number of classes
min_size: Minimum image size for rescaling
max_size: Maximum image size for rescaling
image_mean: Mean for image normalization
image_std: Standard deviation for image normalization
anchor_generator: Anchor generator
head: Detection head
score_thresh: Score threshold for inference
nms_thresh: NMS threshold
detections_per_img: Maximum detections per image
fg_iou_thresh: Foreground IoU threshold
bg_iou_thresh: Background IoU threshold
topk_candidates: Top-k candidates to keep
"""
def retinanet_resnet50_fpn(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> RetinaNet:
"""RetinaNet model with ResNet-50-FPN backbone."""
def retinanet_resnet50_fpn_v2(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> RetinaNet:
"""RetinaNet model with ResNet-50-FPN v2 backbone."""
class SSD(torch.nn.Module):
"""Single Shot MultiBox Detector model."""
def ssd300_vgg16(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> SSD:
"""SSD300 model with VGG-16 backbone."""
def ssdlite320_mobilenet_v3_large(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> SSD:
"""SSDLite320 model with MobileNetV3-Large backbone."""
class FCOS(torch.nn.Module):
"""FCOS (Fully Convolutional One-Stage) object detector."""
def fcos_resnet50_fpn(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> FCOS:
"""FCOS model with ResNet-50-FPN backbone."""Pixel-level classification models for semantic segmentation.
class FCN(torch.nn.Module):
"""
Fully Convolutional Network for semantic segmentation.
Args:
backbone: Feature extraction backbone
classifier: Classification head
aux_classifier: Auxiliary classification head
"""
def fcn_resnet50(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> FCN:
"""FCN model with ResNet-50 backbone."""
def fcn_resnet101(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> FCN:
"""FCN model with ResNet-101 backbone."""
class DeepLabV3(torch.nn.Module):
"""
DeepLabV3 model with atrous spatial pyramid pooling.
Args:
backbone: Feature extraction backbone
classifier: Classification head with ASPP
aux_classifier: Auxiliary classification head
"""
def deeplabv3_resnet50(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> DeepLabV3:
"""DeepLabV3 model with ResNet-50 backbone."""
def deeplabv3_resnet101(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> DeepLabV3:
"""DeepLabV3 model with ResNet-101 backbone."""
def deeplabv3_mobilenet_v3_large(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> DeepLabV3:
"""DeepLabV3 model with MobileNetV3-Large backbone."""
class LRASPP(torch.nn.Module):
"""
Lite R-ASPP model for fast semantic segmentation.
Args:
backbone: Feature extraction backbone
low_channels: Number of low-level feature channels
high_channels: Number of high-level feature channels
num_classes: Number of classes
inter_channels: Number of intermediate channels
"""
def lraspp_mobilenet_v3_large(weights=None, progress: bool = True, num_classes=None, weights_backbone=None, trainable_backbone_layers=None, **kwargs) -> LRASPP:
"""LRASPP model with MobileNetV3-Large backbone."""Models for video understanding and temporal analysis.
class VideoResNet(torch.nn.Module):
"""
3D ResNet architecture for video classification.
Args:
block: 3D block type
conv_makers: Convolution configuration for each layer
layers: Number of blocks per layer
stem: Stem configuration
num_classes: Number of classes
zero_init_residual: Zero-initialize residual connections
"""
def r3d_18(weights=None, progress: bool = True, **kwargs) -> VideoResNet:
"""3D ResNet-18 for video classification."""
def mc3_18(weights=None, progress: bool = True, **kwargs) -> VideoResNet:
"""Mixed Convolution 3D ResNet-18."""
def r2plus1d_18(weights=None, progress: bool = True, **kwargs) -> VideoResNet:
"""R(2+1)D ResNet-18 with factorized convolutions."""
class S3D(torch.nn.Module):
"""Separable 3D CNN architecture."""
def s3d(weights=None, progress: bool = True, **kwargs) -> S3D:
"""S3D model for video classification."""
class MViT(torch.nn.Module):
"""Multiscale Vision Transformer for video understanding."""
def mvit_v1_b(weights=None, progress: bool = True, **kwargs) -> MViT:
"""MViTv1-Base model."""
def mvit_v2_s(weights=None, progress: bool = True, **kwargs) -> MViT:
"""MViTv2-Small model."""
class SwinTransformer3D(torch.nn.Module):
"""3D Swin Transformer for video analysis."""
def swin3d_t(weights=None, progress: bool = True, **kwargs) -> SwinTransformer3D:
"""Swin3D Tiny model."""
def swin3d_s(weights=None, progress: bool = True, **kwargs) -> SwinTransformer3D:
"""Swin3D Small model."""
def swin3d_b(weights=None, progress: bool = True, **kwargs) -> SwinTransformer3D:
"""Swin3D Base model."""Models for estimating optical flow between video frames.
class RAFT(torch.nn.Module):
"""
RAFT (Recurrent All-Pairs Field Transforms) optical flow model.
Args:
feature_encoder: Feature extraction encoder
context_encoder: Context extraction encoder
correlation_block: Correlation block for feature matching
update_block: GRU-based update block
mask_predictor: Flow mask predictor
"""
def raft_large(weights=None, progress: bool = True, **kwargs) -> RAFT:
"""RAFT Large model for optical flow estimation."""
def raft_small(weights=None, progress: bool = True, **kwargs) -> RAFT:
"""RAFT Small model for optical flow estimation."""Quantized versions of popular models for efficient inference.
class QuantizableResNet(torch.nn.Module):
"""Quantizable ResNet architecture."""
# Quantized classification models
def resnet18(weights=None, progress: bool = True, quantize: bool = False, **kwargs):
"""Quantized ResNet-18 model."""
def resnet50(weights=None, progress: bool = True, quantize: bool = False, **kwargs):
"""Quantized ResNet-50 model."""
class QuantizableMobileNetV2(torch.nn.Module):
"""Quantizable MobileNetV2 architecture."""
def mobilenet_v2(weights=None, progress: bool = True, quantize: bool = False, **kwargs):
"""Quantized MobileNetV2 model."""
class QuantizableMobileNetV3(torch.nn.Module):
"""Quantizable MobileNetV3 architecture."""
def mobilenet_v3_large(weights=None, progress: bool = True, quantize: bool = False, **kwargs):
"""Quantized MobileNetV3-Large model."""Utilities for extracting intermediate features from pre-trained models.
def create_feature_extractor(model: torch.nn.Module, return_nodes: dict, train_return_nodes=None, eval_return_nodes=None, tracer_kwargs=None, suppress_diff_warning: bool = False):
"""
Creates a feature extractor from any model.
Args:
model (torch.nn.Module): Model to extract features from
return_nodes (dict): Dict mapping node names to user-specified keys
train_return_nodes (dict, optional): Nodes to return during training
eval_return_nodes (dict, optional): Nodes to return during evaluation
tracer_kwargs (dict, optional): Keyword arguments for symbolic tracer
suppress_diff_warning (bool): Suppress difference warning
Returns:
FeatureExtractor: Model wrapper that returns intermediate features
"""
def get_graph_node_names(model: torch.nn.Module, tracer_kwargs=None, suppress_diff_warning: bool = False):
"""
Gets graph node names for feature extraction.
Args:
model (torch.nn.Module): Model to analyze
tracer_kwargs (dict, optional): Keyword arguments for symbolic tracer
suppress_diff_warning (bool): Suppress difference warning
Returns:
tuple: (train_nodes, eval_nodes) containing node names
"""import torchvision.models as models
import torch
# Load a pre-trained ResNet-50
model = models.resnet50(weights='DEFAULT')
model.eval()
# Load model without weights
model = models.resnet50(weights=None)
# Load with specific weights
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
# Modify for different number of classes
model = models.resnet50(weights='DEFAULT')
num_classes = 10
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
# Load pre-trained Faster R-CNN
model = models.detection.fasterrcnn_resnet50_fpn(weights='DEFAULT')
model.eval()
# Prepare image
transform = transforms.Compose([transforms.ToTensor()])
image = Image.open('image.jpg')
image_tensor = transform(image)
# Inference
with torch.no_grad():
predictions = model([image_tensor])
# Access results
boxes = predictions[0]['boxes']
scores = predictions[0]['scores']
labels = predictions[0]['labels']

import torchvision.models as models
from torchvision.models.feature_extraction import create_feature_extractor
# Load pre-trained model
model = models.resnet50(weights='DEFAULT')
# Create feature extractor
return_nodes = {
'layer1.2.conv3': 'layer1',
'layer2.3.conv3': 'layer2',
'layer3.5.conv3': 'layer3',
'layer4.2.conv3': 'layer4'
}
feature_extractor = create_feature_extractor(model, return_nodes)
# Extract features
with torch.no_grad():
features = feature_extractor(input_tensor)
# Access extracted features
layer1_features = features['layer1']
layer2_features = features['layer2']

import torchvision.models.video as video_models
import torch
# Load pre-trained video model
model = video_models.r3d_18(weights='DEFAULT')
model.eval()
# Prepare video tensor (batch_size, channels, frames, height, width)
video_tensor = torch.randn(1, 3, 16, 224, 224)
# Inference
with torch.no_grad():
predictions = model(video_tensor)
predicted_class = torch.argmax(predictions, dim=1)

Install with Tessl CLI
npx tessl i tessl/pypi-torchvision