Computer vision library for PyTorch with datasets, model architectures, and image/video transforms.
TorchVision utilities provide essential functions for image visualization, tensor manipulation, and drawing operations. These utilities are particularly useful for debugging, result visualization, and creating publication-quality figures from computer vision model outputs.
Functions for creating image grids and saving tensor images to files.
def make_grid(tensor, nrow: int = 8, padding: int = 2, normalize: bool = False, value_range=None, scale_each: bool = False, pad_value: float = 0.0):
    """
    Make a grid of images from a tensor.

    Args:
        tensor (Tensor): 4D mini-batch tensor of shape (B x C x H x W)
            or list of images all of same size
        nrow (int): Number of images displayed in each row of the grid
        padding (int): Amount of padding between images
        normalize (bool): If True, shift image to range (0, 1) by subtracting
            minimum and dividing by maximum
        value_range (tuple, optional): Tuple (min, max) for normalization
        scale_each (bool): If True, scale each image independently
        pad_value (float): Value for padding pixels

    Returns:
        Tensor: Image grid tensor of shape (3 x H x W)
    """
def save_image(tensor, fp, nrow: int = 8, padding: int = 2, normalize: bool = False, value_range=None, scale_each: bool = False, pad_value: float = 0.0, format=None):
    """
    Save tensor as image file.

    Args:
        tensor (Tensor): Image tensor to save
        fp (str or file object): File path or file object to write to
        nrow (int): Number of images displayed in each row
        padding (int): Amount of padding between images
        normalize (bool): If True, shift image to range (0, 1)
        value_range (tuple, optional): Tuple (min, max) for normalization
        scale_each (bool): If True, scale each image independently
        pad_value (float): Value for padding pixels
        format (str, optional): Image format to use ('PNG', 'JPEG', etc.)
    """


# Functions for drawing and visualizing object detection results.
def draw_bounding_boxes(image: torch.Tensor, boxes: torch.Tensor, labels=None, colors=None, fill: bool = False, width: int = 1, font=None, font_size: int = 10):
    """
    Draw bounding boxes on image.

    Args:
        image (Tensor): Image tensor of shape (3, H, W) and dtype uint8
        boxes (Tensor): Bounding boxes of shape (N, 4) in format [x1, y1, x2, y2]
        labels (list, optional): List of labels for each bounding box
        colors (list, optional): List of colors for each bounding box
        fill (bool): If True, fill bounding boxes with color
        width (int): Width of bounding box lines
        font (str, optional): Font name for labels
        font_size (int): Font size for labels

    Returns:
        Tensor: Image tensor with drawn bounding boxes
    """


# Functions for overlaying segmentation masks on images.
def draw_segmentation_masks(image: torch.Tensor, masks: torch.Tensor, alpha: float = 0.8, colors=None):
    """
    Draw segmentation masks on image.

    Args:
        image (Tensor): Image tensor of shape (3, H, W) and dtype uint8
        masks (Tensor): Boolean masks tensor of shape (N, H, W) where N is number of masks
        alpha (float): Transparency level for masks (0.0 fully transparent, 1.0 fully opaque)
        colors (list, optional): List of colors for each mask. If None, generates random colors

    Returns:
        Tensor: Image tensor with overlaid segmentation masks
    """


# Functions for drawing keypoints and pose estimation results.
def draw_keypoints(image: torch.Tensor, keypoints: torch.Tensor, connectivity=None, colors=None, radius: int = 2, width: int = 3):
    """
    Draw keypoints on image.

    Args:
        image (Tensor): Image tensor of shape (3, H, W) and dtype uint8
        keypoints (Tensor): Keypoints tensor of shape (N, K, 3) where N is number of instances,
            K is number of keypoints, and last dim is [x, y, visibility]
        connectivity (list, optional): List of connections between keypoints as pairs of indices
        colors (list, optional): List of colors for keypoints and connections
        radius (int): Radius of keypoint circles
        width (int): Width of connection lines

    Returns:
        Tensor: Image tensor with drawn keypoints and connections
    """


# Functions for visualizing optical flow fields.
def flow_to_image(flow: torch.Tensor):
    """
    Convert optical flow to RGB image representation.

    Args:
        flow (Tensor): Optical flow tensor of shape (2, H, W) where first channel
            is horizontal flow and second channel is vertical flow

    Returns:
        Tensor: RGB image tensor of shape (3, H, W) representing flow field
        using color coding (hue for direction, saturation for magnitude)
    """


# Internal utility functions used by other TorchVision components.
def _Image_fromarray(ndarray, mode=None):
"""
Internal PIL Image creation function.
Args:
ndarray: NumPy array to convert to PIL Image
mode (str, optional): PIL image mode
Returns:
PIL Image: Created PIL Image object
"""import torch
import torchvision.utils as utils
from torchvision import transforms
import matplotlib.pyplot as plt

# Create batch of random images (simulating model outputs)
batch_size, channels, height, width = 16, 3, 64, 64
images = torch.randint(0, 256, (batch_size, channels, height, width), dtype=torch.uint8)

# make_grid's normalize path performs in-place float arithmetic, which
# raises a dtype error on uint8 input — convert to float first.
images_float = images.float() / 255.0

# Create image grid
grid = utils.make_grid(images_float, nrow=4, padding=2, normalize=True)

# Display using matplotlib (imshow expects H x W x C ordering)
plt.figure(figsize=(10, 10))
plt.imshow(grid.permute(1, 2, 0))
plt.axis('off')
plt.show()

# Save grid to file
utils.save_image(images_float, 'output_grid.png', nrow=4, padding=2, normalize=True)

import torch
import torchvision.utils as utils
from PIL import Image
import torchvision.transforms as transforms

# Load and prepare image
image = Image.open('image.jpg')
transform = transforms.ToTensor()
image_tensor = transform(image)
image_uint8 = (image_tensor * 255).byte()

# Example detection results (x1, y1, x2, y2 format)
boxes = torch.tensor([
    [50, 50, 200, 150],    # First object
    [300, 100, 450, 250],  # Second object
    [100, 300, 250, 400],  # Third object
])

# Labels for detected objects
labels = ['person', 'car', 'dog']
# Colors for bounding boxes (optional)
colors = ['red', 'blue', 'green']

# Draw bounding boxes
# NOTE(review): torchvision ignores font_size when no `font` is supplied — confirm
result = utils.draw_bounding_boxes(
    image_uint8,
    boxes,
    labels=labels,
    colors=colors,
    width=3,
    font_size=20
)

# Convert back to PIL and display
result_pil = transforms.ToPILImage()(result)
result_pil.show()

# Save result
result_pil.save('detection_result.jpg')

import torch
import torchvision.utils as utils
from torchvision import transforms

# Load image (random uint8 stand-in for a real photo)
image_tensor = torch.randint(0, 256, (3, 300, 300), dtype=torch.uint8)

# Create example segmentation masks
mask1 = torch.zeros(300, 300, dtype=torch.bool)
mask1[50:150, 50:150] = True    # Square mask
mask2 = torch.zeros(300, 300, dtype=torch.bool)
mask2[200:280, 200:280] = True  # Another square mask
masks = torch.stack([mask1, mask2])

# Draw masks on image
result = utils.draw_segmentation_masks(
    image_tensor,
    masks,
    alpha=0.7,
    colors=['red', 'blue']
)

# Display result
result_pil = transforms.ToPILImage()(result)
result_pil.show()

import torch
import torchvision.utils as utils
from torchvision import transforms

# Create example image
image = torch.randint(0, 256, (3, 400, 400), dtype=torch.uint8)

# Example keypoints for human pose (17 keypoints in COCO format)
# Shape: (num_people, num_keypoints, 3) where last dim is [x, y, visibility]
keypoints = torch.tensor([
    [
        [200, 100, 1],  # nose
        [190, 120, 1],  # left eye
        [210, 120, 1],  # right eye
        [180, 130, 1],  # left ear
        [220, 130, 1],  # right ear
        [170, 200, 1],  # left shoulder
        [230, 200, 1],  # right shoulder
        [160, 280, 1],  # left elbow
        [240, 280, 1],  # right elbow
        [150, 350, 1],  # left wrist
        [250, 350, 1],  # right wrist
        [180, 300, 1],  # left hip
        [220, 300, 1],  # right hip
        [175, 360, 1],  # left knee
        [225, 360, 1],  # right knee
        [170, 390, 1],  # left ankle
        [230, 390, 1],  # right ankle
    ]
], dtype=torch.float)

# Define skeleton connections (COCO format)
connectivity = [
    (0, 1), (0, 2),      # nose to eyes
    (1, 3), (2, 4),      # eyes to ears
    (5, 6),              # shoulders
    (5, 7), (7, 9),      # left arm
    (6, 8), (8, 10),     # right arm
    (5, 11), (6, 12),    # shoulders to hips
    (11, 12),            # hips
    (11, 13), (13, 15),  # left leg
    (12, 14), (14, 16),  # right leg
]

# Draw keypoints
# NOTE(review): torchvision's draw_keypoints takes a single color for
# `colors`, not a per-connection list — verify against the installed version.
result = utils.draw_keypoints(
    image,
    keypoints,
    connectivity=connectivity,
    colors=['red'] * len(connectivity),
    radius=5,
    width=2
)

# Display result
result_pil = transforms.ToPILImage()(result)
result_pil.show()

import torch
import torchvision.utils as utils
from torchvision import transforms
import numpy as np

# Create synthetic optical flow field
height, width = 256, 256
y, x = np.meshgrid(np.arange(height), np.arange(width), indexing='ij')

# Create circular flow pattern around the image center
center_x, center_y = width // 2, height // 2
dx = -(y - center_y) * 0.1
dy = (x - center_x) * 0.1

# Convert to tensor of shape (2, H, W): [horizontal, vertical] flow
flow = torch.tensor(np.stack([dx, dy]), dtype=torch.float32)

# Convert flow to RGB image
flow_image = utils.flow_to_image(flow)

# Display flow visualization
flow_pil = transforms.ToPILImage()(flow_image)
flow_pil.show()

# Save flow visualization
flow_pil.save('optical_flow.png')

import torch
import torchvision.utils as utils
from torchvision import transforms
import matplotlib.pyplot as plt


def visualize_batch_predictions(images, predictions, labels, num_images=8):
    """
    Visualize batch of images with predictions and ground truth labels.

    Args:
        images: Batch of images tensor, shape (B, 3, H, W)
        predictions: Model predictions (per-class scores), shape (B, num_classes)
        labels: Ground truth labels, shape (B,)
        num_images: Number of images to visualize
    """
    # Clamp to the actual batch size so the title loop cannot index
    # out of bounds when the caller asks for more images than exist.
    num_images = min(num_images, images.size(0))

    # Select subset of images
    images = images[:num_images]
    predictions = predictions[:num_images]
    labels = labels[:num_images]

    # Denormalize images (assuming ImageNet normalization)
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    images = images * std + mean
    images = torch.clamp(images, 0, 1)

    # Create grid
    grid = utils.make_grid(images, nrow=4, padding=2)

    # Display (imshow expects H x W x C ordering)
    plt.figure(figsize=(12, 8))
    plt.imshow(grid.permute(1, 2, 0))
    plt.axis('off')

    # Add prediction vs ground truth info
    pred_classes = torch.argmax(predictions, dim=1)
    title = "Predictions vs Ground Truth\n"
    for i in range(num_images):
        title += f"Img{i+1}: Pred={pred_classes[i].item()}, GT={labels[i].item()} "
        if i % 4 == 3:
            title += "\n"
    plt.title(title)
    plt.tight_layout()
    plt.show()


# Example usage
batch_images = torch.randn(16, 3, 224, 224)
batch_predictions = torch.randn(16, 10)  # 10 classes
batch_labels = torch.randint(0, 10, (16,))
visualize_batch_predictions(batch_images, batch_predictions, batch_labels)

import torch
import torchvision.utils as utils
from torchvision import transforms


def create_comparison_grid(original, processed, labels=None):
    """
    Create side-by-side comparison of original and processed images.

    Args:
        original: Batch of original images, shape (B, C, H, W)
        processed: Batch of processed images, same shape as `original`
        labels: Optional labels for images (currently unused)

    Returns:
        Tensor: normalized image grid with originals in the left column
        and processed images in the right column.
    """
    batch_size = original.size(0)
    # Interleave original and processed images
    comparison = torch.zeros(batch_size * 2, *original.shape[1:])
    comparison[0::2] = original
    comparison[1::2] = processed
    # Create grid with 2 columns (original, processed)
    grid = utils.make_grid(comparison, nrow=2, padding=2, normalize=True)
    return grid


# Example: Before and after augmentation
original_images = torch.randint(0, 256, (4, 3, 128, 128), dtype=torch.uint8)

# Apply some processing (e.g., color jitter)
from torchvision.transforms import ColorJitter
jitter = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3)
# ColorJitter here produces PIL Images; keep them in a plain list —
# torch.stack() on PIL Images would raise a TypeError.
jittered = [jitter(transforms.ToPILImage()(img)) for img in original_images]
processed_images = torch.stack([transforms.ToTensor()(img) for img in jittered])
processed_images = (processed_images * 255).byte()

# Create comparison
comparison_grid = create_comparison_grid(original_images, processed_images)

# Display
comparison_pil = transforms.ToPILImage()(comparison_grid)
comparison_pil.show()

# Install with Tessl CLI
npx tessl i tessl/pypi-torchvision