# Computer vision library for PyTorch with datasets, model architectures,
# and image/video transforms.
#
# TorchVision TV Tensors provide enhanced tensor types that preserve metadata
# and semantics through transformations. These specialized tensors enable
# transforms to handle multiple data types (images, videos, bounding boxes,
# masks, keypoints) consistently while maintaining their specific properties
# and constraints.
#
# Foundation class for all TorchVision tensor types with enhanced metadata
# support.
class TVTensor(torch.Tensor):
    """
    Base class for all torchvision tensor types.

    Extends torch.Tensor with metadata preservation through transformations.
    Provides automatic wrapping and unwrapping of tensor operations while
    maintaining type-specific metadata and constraints.
    """

    def __new__(cls, data, **kwargs): ...

    def wrap_like(self, other, **kwargs):
        """Wrap tensor with same type and metadata as another TVTensor.

        Args:
            other: Reference TVTensor whose type and metadata are copied.
            **kwargs: Metadata overrides for the wrapped result.
        """


# Enhanced image tensors that preserve image semantics and properties.
class Image(TVTensor):
    """
    Image tensor type with preserved image semantics.

    Inherits from torch.Tensor and maintains image-specific properties
    through transformations. Ensures operations maintain image constraints
    like channel ordering and value ranges.

    Args:
        data: Image data as tensor, PIL Image, or numpy array
        dtype: Data type for the tensor (default: inferred)
        device: Device to place tensor on
        requires_grad: Whether tensor requires gradients

    Shape:
        - (C, H, W): Single image with C channels, H height, W width
        - (N, C, H, W): Batch of N images
    """

    def __new__(cls, data, *, dtype=None, device=None, requires_grad=None): ...

    @property
    def spatial_size(self) -> tuple:
        """Get spatial dimensions (height, width)."""

    @property
    def num_channels(self) -> int:
        """Get number of channels."""

    @property
    def image_size(self) -> tuple:
        """Get image size as (height, width)."""


# Specialized tensors for temporal video data with frame sequence handling.
class Video(TVTensor):
    """
    Video tensor type for temporal data sequences.

    Handles temporal dimension and preserves video-specific properties
    through transformations. Maintains frame relationships and temporal
    consistency.

    Args:
        data: Video data as tensor or array
        dtype: Data type for the tensor
        device: Device to place tensor on
        requires_grad: Whether tensor requires gradients

    Shape:
        - (T, C, H, W): Single video with T frames, C channels, H height, W width
        - (N, T, C, H, W): Batch of N videos
    """

    def __new__(cls, data, *, dtype=None, device=None, requires_grad=None): ...

    @property
    def num_frames(self) -> int:
        """Get number of frames."""

    @property
    def frame_size(self) -> tuple:
        """Get frame dimensions (height, width)."""

    @property
    def temporal_size(self) -> int:
        """Get temporal dimension size."""


# Bounding box tensors with format awareness and coordinate system handling.
class BoundingBoxes(TVTensor):
    """
    Tensor of bounding boxes carrying format and canvas-size metadata.

    Tracks the coordinate layout of the boxes and the size of the image
    they belong to, so that transformations can update coordinates and
    metadata together while keeping boxes valid and format-consistent.

    Args:
        data: Bounding box coordinates as tensor
        format: Box format ('XYXY', 'XYWH', 'CXCYWH')
        canvas_size: Image dimensions as (height, width)
        dtype: Data type for coordinates
        device: Device to place tensor on
        requires_grad: Whether tensor requires gradients

    Shape:
        - (4,): Single bounding box [x1, y1, x2, y2] or format-specific
        - (N, 4): N bounding boxes
    """

    def __new__(cls, data, *, format: str, canvas_size: tuple, dtype=None, device=None, requires_grad=None): ...

    @property
    def format(self) -> str:
        """Coordinate layout of the boxes ('XYXY', 'XYWH', 'CXCYWH')."""

    @property
    def canvas_size(self) -> tuple:
        """Dimensions of the underlying image as (height, width)."""

    @property
    def clamping_mode(self) -> str:
        """Strategy applied when clamping out-of-bounds boxes."""

    def clamp(self) -> 'BoundingBoxes':
        """Return boxes clamped to the canvas boundaries."""

    def convert_format(self, format: str) -> 'BoundingBoxes':
        """Return the boxes converted to another coordinate layout."""
class BoundingBoxFormat:
    """String constants naming the supported bounding box coordinate layouts."""

    # Corner form: [x_min, y_min, x_max, y_max]
    XYXY: str = "XYXY"
    # Corner-plus-extent form: [x_min, y_min, width, height]
    XYWH: str = "XYWH"
    # Center form: [center_x, center_y, width, height]
    CXCYWH: str = "CXCYWH"
def is_rotated_bounding_format(format: str) -> bool:
    """
    Check if bounding box format supports rotated boxes.

    Args:
        format (str): Bounding box format string (e.g. 'XYXY', 'XYWHR')

    Returns:
        bool: True if format supports rotation
    """
    # Rotated formats either carry an explicit angle component ('R' suffix)
    # or enumerate all four corner points.
    # NOTE(review): set mirrors torchvision's rotated formats
    # (XYWHR, CXCYWHR, XYXYXYXY) — confirm against the installed version.
    return str(format).upper() in {"XYWHR", "CXCYWHR", "XYXYXYXY"}


# Segmentation mask tensors for pixel-level annotations and predictions.
class Mask(TVTensor):
    """
    Segmentation mask tensor type for pixel-level annotations.

    Handles boolean or integer masks while preserving spatial
    relationships and mask-specific properties through transformations.

    Args:
        data: Mask data as tensor or array (boolean or integer)
        dtype: Data type (typically torch.bool or torch.uint8)
        device: Device to place tensor on
        requires_grad: Whether tensor requires gradients

    Shape:
        - (H, W): Single binary mask
        - (N, H, W): N masks (e.g., instance segmentation)
        - (C, H, W): Multi-class segmentation mask
    """

    def __new__(cls, data, *, dtype=None, device=None, requires_grad=None): ...

    @property
    def spatial_size(self) -> tuple:
        """Get spatial dimensions (height, width)."""

    @property
    def num_masks(self) -> int:
        """Get number of individual masks."""

    def to_binary(self) -> 'Mask':
        """Convert to binary mask format."""


# Keypoint tensors for pose estimation and landmark detection tasks.
class KeyPoints(TVTensor):
    """
    Keypoint tensor with canvas size and connectivity information.

    Handles keypoint coordinates with visibility information and
    maintains spatial constraints through transformations.

    Args:
        data: Keypoint coordinates and visibility
        canvas_size: Image dimensions as (height, width)
        dtype: Data type for coordinates
        device: Device to place tensor on
        requires_grad: Whether tensor requires gradients

    Shape:
        - (K, 2): K keypoints with [x, y] coordinates
        - (K, 3): K keypoints with [x, y, visibility]
        - (N, K, 2): N instances with K keypoints each
        - (N, K, 3): N instances with K keypoints and visibility
    """

    def __new__(cls, data, *, canvas_size: tuple, dtype=None, device=None, requires_grad=None): ...

    @property
    def canvas_size(self) -> tuple:
        """Get canvas dimensions (height, width)."""

    @property
    def num_keypoints(self) -> int:
        """Get number of keypoints per instance."""

    @property
    def num_instances(self) -> int:
        """Get number of instances."""

    def has_visibility(self) -> bool:
        """Check if keypoints have visibility information."""

    def clamp(self) -> 'KeyPoints':
        """Clamp keypoints to canvas boundaries."""


# Utilities for working with TV tensors and managing type consistency.
def wrap(wrappee, *, like, **kwargs):
    """
    Wrap tensor as same type as reference tensor.

    Args:
        wrappee: Tensor to wrap
        like: Reference TV tensor to match type and metadata
        **kwargs: Additional arguments for tensor creation
            (presumably metadata overrides such as format or
            canvas_size — confirm against the torchvision docs)

    Returns:
        TV tensor of same type as 'like' parameter
    """
def set_return_type(return_type: str):
    """
    Set return type for tensor operations.

    Args:
        return_type (str): Type to return from operations
            ('TVTensor', 'Tensor', or 'auto')
    """


# Example: constructing Image tensors from PIL images and raw tensors.
from torchvision.tv_tensors import Image
import torch
from PIL import Image as PILImage

# Create Image tensor from different sources
pil_image = PILImage.open('image.jpg')
image_tensor = Image(pil_image)
print(f"Image shape: {image_tensor.shape}")
print(f"Image size: {image_tensor.image_size}")
print(f"Channels: {image_tensor.num_channels}")

# Create from tensor data
tensor_data = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
image_tensor = Image(tensor_data)

# Image tensors preserve type through operations
scaled_image = image_tensor * 0.5
print(f"Scaled image type: {type(scaled_image)}")  # Still Image type

# Batch of images
batch_data = torch.randint(0, 256, (8, 3, 224, 224), dtype=torch.uint8)
batch_images = Image(batch_data)
print(f"Batch shape: {batch_images.shape}")

# Example: creating, converting, and clamping BoundingBoxes.
from torchvision.tv_tensors import BoundingBoxes, BoundingBoxFormat
import torch

# Create bounding boxes in XYXY format
boxes_data = torch.tensor([
    [10, 20, 100, 150],
    [50, 30, 200, 180],
    [75, 60, 150, 200]
], dtype=torch.float)
canvas_size = (240, 320)  # (height, width)
boxes = BoundingBoxes(
    boxes_data,
    format=BoundingBoxFormat.XYXY,
    canvas_size=canvas_size
)
print(f"Boxes format: {boxes.format}")
print(f"Canvas size: {boxes.canvas_size}")

# Convert between formats
boxes_xywh = boxes.convert_format(BoundingBoxFormat.XYWH)
print(f"Converted format: {boxes_xywh.format}")
print(f"XYWH boxes: {boxes_xywh}")

# Clamp boxes to canvas boundaries
# (useful after transformations that might move boxes out of bounds)
clamped_boxes = boxes.clamp()

# Create boxes in different format
center_boxes = BoundingBoxes(
    torch.tensor([[50, 60, 80, 100]]),  # [cx, cy, w, h]
    format=BoundingBoxFormat.CXCYWH,
    canvas_size=canvas_size
)

# Example: binary, instance, and semantic segmentation masks.
from torchvision.tv_tensors import Mask
import torch

# Binary segmentation mask
binary_mask_data = torch.zeros(240, 320, dtype=torch.bool)
binary_mask_data[50:150, 60:160] = True  # Create rectangular mask
binary_mask = Mask(binary_mask_data)
print(f"Binary mask shape: {binary_mask.shape}")
print(f"Spatial size: {binary_mask.spatial_size}")

# Multi-instance masks (e.g., for instance segmentation)
num_instances = 3
instance_masks_data = torch.zeros(num_instances, 240, 320, dtype=torch.bool)
instance_masks_data[0, 20:80, 30:90] = True      # Instance 1
instance_masks_data[1, 100:160, 150:210] = True  # Instance 2
instance_masks_data[2, 180:220, 50:150] = True   # Instance 3
instance_masks = Mask(instance_masks_data)
print(f"Instance masks shape: {instance_masks.shape}")
print(f"Number of masks: {instance_masks.num_masks}")

# Integer-valued masks (e.g., semantic segmentation)
semantic_mask_data = torch.zeros(240, 320, dtype=torch.uint8)
semantic_mask_data[50:100, 50:100] = 1    # Class 1
semantic_mask_data[150:200, 150:200] = 2  # Class 2
semantic_mask = Mask(semantic_mask_data)

# Example: COCO-style keypoints with visibility flags.
from torchvision.tv_tensors import KeyPoints
import torch

# COCO-style human pose keypoints (17 keypoints)
# Format: [x, y, visibility] where visibility: 0=not labeled, 1=labeled but not visible, 2=labeled and visible
keypoints_data = torch.tensor([
    [160, 80, 2],   # nose
    [155, 85, 2],   # left_eye
    [165, 85, 2],   # right_eye
    [150, 90, 2],   # left_ear
    [170, 90, 2],   # right_ear
    [140, 120, 2],  # left_shoulder
    [180, 120, 2],  # right_shoulder
    [130, 150, 2],  # left_elbow
    [190, 150, 1],  # right_elbow (labeled but occluded)
    [125, 180, 0],  # left_wrist (not labeled)
    [195, 180, 2],  # right_wrist
    [150, 200, 2],  # left_hip
    [170, 200, 2],  # right_hip
    [145, 240, 2],  # left_knee
    [175, 240, 2],  # right_knee
    [140, 280, 2],  # left_ankle
    [180, 280, 2],  # right_ankle
], dtype=torch.float)
canvas_size = (320, 240)
keypoints = KeyPoints(keypoints_data, canvas_size=canvas_size)
print(f"Keypoints shape: {keypoints.shape}")
print(f"Number of keypoints: {keypoints.num_keypoints}")
print(f"Has visibility: {keypoints.has_visibility()}")
print(f"Canvas size: {keypoints.canvas_size}")

# Multiple person keypoints
batch_keypoints_data = torch.randn(5, 17, 3)  # 5 people, 17 keypoints, [x,y,vis]
batch_keypoints = KeyPoints(batch_keypoints_data, canvas_size=canvas_size)
print(f"Batch keypoints shape: {batch_keypoints.shape}")
print(f"Number of instances: {batch_keypoints.num_instances}")

# Clamp keypoints to image boundaries
clamped_keypoints = keypoints.clamp()

# Example: Video tensors and their temporal properties.
from torchvision.tv_tensors import Video
import torch

# Create video tensor (16 frames, 3 channels, 224x224)
video_data = torch.randint(0, 256, (16, 3, 224, 224), dtype=torch.uint8)
video = Video(video_data)
print(f"Video shape: {video.shape}")
print(f"Number of frames: {video.num_frames}")
print(f"Frame size: {video.frame_size}")
print(f"Temporal size: {video.temporal_size}")

# Batch of videos
batch_video_data = torch.randint(0, 256, (4, 16, 3, 224, 224), dtype=torch.uint8)
batch_videos = Video(batch_video_data)
print(f"Batch videos shape: {batch_videos.shape}")

# Video tensors maintain type through operations
downsampled_video = video[:8]  # Take first 8 frames
print(f"Downsampled video type: {type(downsampled_video)}")

# Example: v2 transforms applied jointly to an image and its boxes.
from torchvision.tv_tensors import Image, BoundingBoxes, BoundingBoxFormat
from torchvision.transforms import v2
import torch

# Create sample data
image = Image(torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8))
boxes = BoundingBoxes(
    torch.tensor([[50, 50, 200, 200], [300, 150, 450, 350]]),
    format=BoundingBoxFormat.XYXY,
    canvas_size=(480, 640)
)
print("Before transform:")
print(f"Image shape: {image.shape}")
print(f"Boxes: {boxes}")
print(f"Boxes format: {boxes.format}")

# Apply transforms that work with multiple tensor types
transform = v2.Compose([
    v2.RandomHorizontalFlip(p=1.0),  # Always flip for demonstration
    v2.Resize((224, 224)),
    v2.ToDtype(torch.float32, scale=True)
])

# Transform both image and boxes together
transformed_image, transformed_boxes = transform(image, boxes)
print("\nAfter transform:")
print(f"Image shape: {transformed_image.shape}")
print(f"Image type: {type(transformed_image)}")
print(f"Boxes: {transformed_boxes}")
print(f"Boxes type: {type(transformed_boxes)}")
print(f"Boxes format: {transformed_boxes.format}")
print(f"New canvas size: {transformed_boxes.canvas_size}")

# Example: custom operations that preserve TV tensor types via wrap().
from torchvision.tv_tensors import Image, wrap
import torch

def custom_image_operation(img):
    """
    Custom operation that preserves TV tensor type.
    """
    # Perform some operation on the underlying tensor
    processed = img * 0.8 + 0.1  # Adjust brightness
    # Wrap result to maintain TV tensor type and metadata
    return wrap(processed, like=img)

def batch_process_images(images):
    """
    Process batch of images while maintaining types.
    """
    results = []
    for img in images:
        processed = custom_image_operation(img)
        results.append(processed)
    return torch.stack(results)

# Test custom operations
image = Image(torch.rand(3, 224, 224))
processed_image = custom_image_operation(image)
print(f"Original type: {type(image)}")
print(f"Processed type: {type(processed_image)}")

# Works with batches too
batch_images = [Image(torch.rand(3, 224, 224)) for _ in range(4)]
batch_result = batch_process_images(batch_images)
print(f"Batch result shape: {batch_result.shape}")
print(f"Batch result type: {type(batch_result)}")

# Example: end-to-end detection data pipeline with joint augmentation.
from torchvision.tv_tensors import Image, BoundingBoxes, Mask, BoundingBoxFormat
from torchvision.transforms import v2
import torch

def detection_pipeline():
    """
    Example object detection data pipeline using TV tensors.
    """
    # Simulate loading detection data
    image = Image(torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8))
    boxes = BoundingBoxes(
        torch.tensor([[100, 100, 300, 250], [200, 150, 400, 350]]),
        format=BoundingBoxFormat.XYXY,
        canvas_size=(480, 640)
    )

    # Instance masks for each detection
    masks_data = torch.zeros(2, 480, 640, dtype=torch.bool)
    masks_data[0, 100:250, 100:300] = True
    masks_data[1, 150:350, 200:400] = True
    masks = Mask(masks_data)

    # Labels for each detection
    labels = torch.tensor([1, 2])  # Class IDs

    print("Original data:")
    print(f"Image: {image.shape}, {type(image)}")
    print(f"Boxes: {boxes.shape}, {type(boxes)}")
    print(f"Masks: {masks.shape}, {type(masks)}")

    # Data augmentation pipeline
    transform = v2.Compose([
        v2.RandomHorizontalFlip(p=0.5),
        v2.RandomResizedCrop((416, 416), scale=(0.8, 1.0)),
        v2.ColorJitter(brightness=0.2, contrast=0.2),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Apply transforms (all types are handled automatically)
    aug_image, aug_boxes, aug_masks = transform(image, boxes, masks)

    print("\nAfter augmentation:")
    print(f"Image: {aug_image.shape}, {type(aug_image)}")
    print(f"Boxes: {aug_boxes.shape}, {type(aug_boxes)}")
    print(f"Masks: {aug_masks.shape}, {type(aug_masks)}")
    print(f"Canvas size updated: {aug_boxes.canvas_size}")

    return aug_image, aug_boxes, aug_masks, labels

# Run detection pipeline
processed_data = detection_pipeline()

# Install with Tessl CLI
# npx tessl i tessl/pypi-torchvision