CtrlK
BlogDocsLog inGet started
Tessl Logo

tessl/pypi-torchvision

Computer vision library for PyTorch with datasets, model architectures, and image/video transforms.

Overview
Eval results
Files

utils.mddocs/

Utils

TorchVision utilities provide essential functions for image visualization, tensor manipulation, and drawing operations. These utilities are particularly useful for debugging, result visualization, and creating publication-quality figures from computer vision model outputs.

Capabilities

Image Grid and Visualization

Functions for creating image grids and saving tensor images to files.

def make_grid(tensor, nrow: int = 8, padding: int = 2, normalize: bool = False, value_range=None, scale_each: bool = False, pad_value: float = 0.0):
    """
    Arrange a batch of images into a single grid image.
    
    Args:
        tensor (Tensor): 4D mini-batch tensor of shape (B x C x H x W)
                        or list of images all of same size
        nrow (int): Number of images displayed in each row of the grid
        padding (int): Amount of padding between images
        normalize (bool): If True, shift image to range (0, 1) by subtracting
                         the minimum and dividing by (max - min).
                         NOTE(review): in torchvision this requires a
                         floating-point input tensor
        value_range (tuple, optional): Tuple (min, max) used in place of the
                         tensor's own min/max when normalize is True
        scale_each (bool): If True, scale each image independently rather
                          than over the whole batch
        pad_value (float): Value for padding pixels
        
    Returns:
        Tensor: Image grid tensor of shape (3 x H x W)
    """

def save_image(tensor, fp, nrow: int = 8, padding: int = 2, normalize: bool = False, value_range=None, scale_each: bool = False, pad_value: float = 0.0, format=None):
    """
    Save a tensor (or mini-batch of tensors) as an image file.
    
    A mini-batch is first arranged into a grid (same layout parameters as
    make_grid) before saving.
    
    Args:
        tensor (Tensor): Image tensor to save
        fp (str or file object): File path or file object to write to
        nrow (int): Number of images displayed in each row
        padding (int): Amount of padding between images  
        normalize (bool): If True, shift image to range (0, 1); see make_grid
        value_range (tuple, optional): Tuple (min, max) for normalization
        scale_each (bool): If True, scale each image independently
        pad_value (float): Value for padding pixels
        format (str, optional): Image format to use ('PNG', 'JPEG', etc.);
                                when fp is a path, the format can usually be
                                inferred from the file extension
    """

Bounding Box Visualization

Functions for drawing and visualizing object detection results.

def draw_bounding_boxes(image: torch.Tensor, boxes: torch.Tensor, labels=None, colors=None, fill: bool = False, width: int = 1, font=None, font_size: int = 10):
    """
    Draw bounding boxes (and optional labels) on an image.
    
    Args:
        image (Tensor): Image tensor of shape (3, H, W) and dtype uint8
        boxes (Tensor): Bounding boxes of shape (N, 4) in format [x1, y1, x2, y2];
                        coordinates are in pixel units of the image
        labels (list, optional): List of labels for each bounding box
        colors (list, optional): List of colors for each bounding box
                                 (e.g. color names or RGB tuples)
        fill (bool): If True, fill bounding boxes with color
        width (int): Width of bounding box lines
        font (str, optional): Font name for labels
        font_size (int): Font size for labels
        
    Returns:
        Tensor: Image tensor with drawn bounding boxes
    """

Segmentation Mask Visualization

Functions for overlaying segmentation masks on images.

def draw_segmentation_masks(image: torch.Tensor, masks: torch.Tensor, alpha: float = 0.8, colors=None):
    """
    Overlay segmentation masks on an image with alpha blending.
    
    Args:
        image (Tensor): Image tensor of shape (3, H, W) and dtype uint8
        masks (Tensor): Boolean masks tensor of shape (N, H, W) where N is
                        number of masks; True pixels belong to the mask
        alpha (float): Transparency level for masks (0.0 fully transparent, 1.0 fully opaque)
        colors (list, optional): List of colors for each mask. If None, generates random colors
        
    Returns:
        Tensor: Image tensor with overlaid segmentation masks
    """

Keypoint Visualization

Functions for drawing keypoints and pose estimation results.

def draw_keypoints(image: torch.Tensor, keypoints: torch.Tensor, connectivity=None, colors=None, radius: int = 2, width: int = 3):
    """
    Draw keypoints (and optional skeleton connections) on an image.
    
    Args:
        image (Tensor): Image tensor of shape (3, H, W) and dtype uint8
        keypoints (Tensor): Keypoints tensor. NOTE(review): torchvision's
                           draw_keypoints documents shape (num_instances, K, 2)
                           with [x, y] coordinates; the (N, K, 3)
                           [x, y, visibility] layout described on this page
                           is not the upstream signature — confirm against
                           the installed torchvision version
        connectivity (list, optional): List of connections between keypoints as pairs of indices
        colors (list, optional): Color(s) for keypoints and connections.
                                 NOTE(review): upstream accepts a single
                                 color value here — verify before passing a list
        radius (int): Radius of keypoint circles
        width (int): Width of connection lines
        
    Returns:
        Tensor: Image tensor with drawn keypoints and connections
    """

Optical Flow Visualization

Functions for visualizing optical flow fields.

def flow_to_image(flow: torch.Tensor):
    """
    Convert an optical flow field to an RGB image representation.
    
    Args:
        flow (Tensor): Optical flow tensor of shape (2, H, W) where first channel
                      is horizontal flow and second channel is vertical flow
                      
    Returns:
        Tensor: RGB image tensor of shape (3, H, W) representing flow field
                using color coding (hue for direction, saturation for magnitude)
    """

Internal Utilities

Internal utility functions used by other TorchVision components.

def _Image_fromarray(ndarray, mode=None):
    """
    Internal PIL Image creation function (thin wrapper around
    PIL.Image.fromarray; not part of the public API).
    
    Args:
        ndarray: NumPy array to convert to PIL Image
        mode (str, optional): PIL image mode (e.g. 'RGB', 'L'); inferred
                              from the array when omitted
        
    Returns:
        PIL Image: Created PIL Image object
    """

Usage Examples

Creating Image Grids

import torch
import torchvision.utils as utils
from torchvision import transforms
import matplotlib.pyplot as plt

# Create batch of random images (simulating model outputs)
batch_size, channels, height, width = 16, 3, 64, 64
images = torch.randint(0, 256, (batch_size, channels, height, width), dtype=torch.uint8)

# FIX: make_grid/save_image with normalize=True require a floating-point
# tensor — passing the raw uint8 batch raises. Convert to float in [0, 1].
images_float = images.float() / 255.0

# Create image grid
grid = utils.make_grid(images_float, nrow=4, padding=2, normalize=True)

# Display using matplotlib (imshow expects channels-last, hence the permute)
plt.figure(figsize=(10, 10))
plt.imshow(grid.permute(1, 2, 0))
plt.axis('off')
plt.show()

# Save grid to file
utils.save_image(images_float, 'output_grid.png', nrow=4, padding=2, normalize=True)

Visualizing Object Detection Results

import torch
import torchvision.utils as utils
from PIL import Image
import torchvision.transforms as transforms

# Load and prepare image
image = Image.open('image.jpg')
transform = transforms.ToTensor()
# ToTensor produces a float tensor scaled to [0, 1] ...
image_tensor = transform(image)
# ... but draw_bounding_boxes requires uint8, so rescale to [0, 255]
image_uint8 = (image_tensor * 255).byte()

# Example detection results (x1, y1, x2, y2 format)
# NOTE(review): coordinates are assumed to fit inside the loaded image —
# confirm against the actual image size
boxes = torch.tensor([
    [50, 50, 200, 150],   # First object
    [300, 100, 450, 250], # Second object
    [100, 300, 250, 400]  # Third object
])

# Labels for detected objects (one per box, same order)
labels = ['person', 'car', 'dog']

# Colors for bounding boxes (optional)
colors = ['red', 'blue', 'green']

# Draw bounding boxes
result = utils.draw_bounding_boxes(
    image_uint8, 
    boxes, 
    labels=labels,
    colors=colors,
    width=3,
    font_size=20
)

# Convert back to PIL and display
result_pil = transforms.ToPILImage()(result)
result_pil.show()

# Save result
result_pil.save('detection_result.jpg')

Visualizing Segmentation Masks

import torch
import torchvision.utils as utils
from torchvision import transforms

# Load image
# (random uint8 data stands in for a real (3, H, W) image here)
image_tensor = torch.randint(0, 256, (3, 300, 300), dtype=torch.uint8)

# Create example segmentation masks
# draw_segmentation_masks expects boolean masks of shape (N, H, W)
mask1 = torch.zeros(300, 300, dtype=torch.bool)
mask1[50:150, 50:150] = True  # Square mask

mask2 = torch.zeros(300, 300, dtype=torch.bool) 
mask2[200:280, 200:280] = True  # Another square mask

# Stack into a (2, 300, 300) batch of masks
masks = torch.stack([mask1, mask2])

# Draw masks on image (alpha=0.7 keeps the underlying image visible)
result = utils.draw_segmentation_masks(
    image_tensor,
    masks,
    alpha=0.7,
    colors=['red', 'blue']
)

# Display result
result_pil = transforms.ToPILImage()(result)
result_pil.show()

Visualizing Keypoints

import torch
import torchvision.utils as utils
from torchvision import transforms

# Create example image
image = torch.randint(0, 256, (3, 400, 400), dtype=torch.uint8)

# Example keypoints for human pose (17 keypoints in COCO format).
# The table stores [x, y, visibility] for readability, but torchvision's
# draw_keypoints expects shape (num_instances, K, 2) with [x, y] only,
# so the visibility column is stripped before drawing (FIX below).
keypoints = torch.tensor([
    [
        [200, 100, 1],  # nose
        [190, 120, 1],  # left eye
        [210, 120, 1],  # right eye
        [180, 130, 1],  # left ear
        [220, 130, 1],  # right ear
        [170, 200, 1],  # left shoulder
        [230, 200, 1],  # right shoulder
        [160, 280, 1],  # left elbow
        [240, 280, 1],  # right elbow
        [150, 350, 1],  # left wrist
        [250, 350, 1],  # right wrist
        [180, 300, 1],  # left hip
        [220, 300, 1],  # right hip
        [175, 360, 1],  # left knee
        [225, 360, 1],  # right knee
        [170, 390, 1],  # left ankle
        [230, 390, 1],  # right ankle
    ]
], dtype=torch.float)

# Define skeleton connections (COCO format)
connectivity = [
    (0, 1), (0, 2),          # nose to eyes
    (1, 3), (2, 4),          # eyes to ears
    (5, 6),                  # shoulders
    (5, 7), (7, 9),          # left arm
    (6, 8), (8, 10),         # right arm
    (5, 11), (6, 12),        # shoulders to hips
    (11, 12),                # hips
    (11, 13), (13, 15),      # left leg
    (12, 14), (14, 16),      # right leg
]

# Draw keypoints.
# FIX: pass (N, K, 2) coordinates and a single color — draw_keypoints
# takes one color for all keypoints/connections, not a per-connection list.
result = utils.draw_keypoints(
    image,
    keypoints[..., :2],  # drop the visibility column
    connectivity=connectivity,
    colors="red",
    radius=5,
    width=2
)

# Display result
result_pil = transforms.ToPILImage()(result)
result_pil.show()

Optical Flow Visualization

import torch
import torchvision.utils as utils
from torchvision import transforms
import numpy as np

# Build a synthetic optical flow field: a rotation about the image center.
height, width = 256, 256
rows, cols = np.indices((height, width))

center_x, center_y = width // 2, height // 2
# Horizontal flow opposes the vertical offset and vertical flow follows
# the horizontal offset, which together produce a circular pattern.
horizontal_flow = -(rows - center_y) * 0.1
vertical_flow = (cols - center_x) * 0.1

# Stack into the (2, H, W) float tensor layout flow_to_image expects.
flow = torch.tensor(np.stack([horizontal_flow, vertical_flow]), dtype=torch.float32)

# Convert the flow field to an RGB visualization.
flow_image = utils.flow_to_image(flow)

# Display flow visualization
flow_pil = transforms.ToPILImage()(flow_image)
flow_pil.show()

# Save flow visualization
flow_pil.save('optical_flow.png')

Batch Visualization Pipeline

import torch
import torchvision.utils as utils
from torchvision import transforms
import matplotlib.pyplot as plt

def visualize_batch_predictions(images, predictions, labels, num_images=8):
    """
    Visualize a batch of images alongside predicted and ground-truth classes.

    Args:
        images: Batch of images tensor (assumed ImageNet-normalized)
        predictions: Model predictions (per-class scores, one row per image)
        labels: Ground truth labels
        num_images: Number of images to visualize
    """
    # Work only on the leading subset of the batch.
    subset_imgs = images[:num_images]
    subset_preds = predictions[:num_images]
    subset_labels = labels[:num_images]

    # Undo the per-channel ImageNet normalization and clamp into [0, 1].
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    subset_imgs = torch.clamp(subset_imgs * imagenet_std + imagenet_mean, 0, 1)

    # Lay the images out in a 4-wide grid.
    grid = utils.make_grid(subset_imgs, nrow=4, padding=2)

    # Show the grid (imshow expects channels-last).
    plt.figure(figsize=(12, 8))
    plt.imshow(grid.permute(1, 2, 0))
    plt.axis('off')

    # Build the "prediction vs ground truth" title, one line per grid row.
    predicted_classes = torch.argmax(subset_preds, dim=1)
    pieces = ["Predictions vs Ground Truth\n"]
    for idx in range(num_images):
        pieces.append(
            f"Img{idx+1}: Pred={predicted_classes[idx].item()}, GT={subset_labels[idx].item()} "
        )
        if idx % 4 == 3:
            pieces.append("\n")

    plt.title("".join(pieces))
    plt.tight_layout()
    plt.show()

# Example usage
# Random tensors stand in for a real dataloader batch and model outputs.
batch_images = torch.randn(16, 3, 224, 224)
batch_predictions = torch.randn(16, 10)  # 10 classes
batch_labels = torch.randint(0, 10, (16,))

visualize_batch_predictions(batch_images, batch_predictions, batch_labels)

Custom Visualization Functions

import torch
import torchvision.utils as utils
from torchvision import transforms

def create_comparison_grid(original, processed, labels=None):
    """
    Create a side-by-side comparison grid of original and processed images.

    Each grid row shows one original image (left) next to its processed
    version (right).

    Args:
        original (Tensor): Batch of original images, shape (B, C, H, W)
        processed (Tensor): Batch of processed images, same shape as original
        labels: Optional labels for images (currently unused; kept for
                backward compatibility / future annotation support)

    Returns:
        Tensor: Float grid tensor normalized to [0, 1]

    Raises:
        ValueError: If original and processed have different shapes.
    """
    # Fail fast with a clear message instead of a confusing broadcast error
    # during the interleaved assignment below.
    if original.shape != processed.shape:
        raise ValueError(
            f"original and processed must have the same shape, "
            f"got {tuple(original.shape)} vs {tuple(processed.shape)}"
        )

    batch_size = original.size(0)

    # Interleave: even indices hold originals, odd indices processed.
    # The buffer is float (torch.zeros default) so make_grid(normalize=True)
    # works even when the inputs are uint8 — normalize requires float.
    comparison = torch.zeros(batch_size * 2, *original.shape[1:])
    comparison[0::2] = original
    comparison[1::2] = processed

    # nrow=2 puts each (original, processed) pair on its own grid row.
    grid = utils.make_grid(comparison, nrow=2, padding=2, normalize=True)

    return grid

# Example: Before and after augmentation
original_images = torch.randint(0, 256, (4, 3, 128, 128), dtype=torch.uint8)

# Apply some processing (e.g., color jitter)
from torchvision.transforms import ColorJitter
jitter = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3)

# FIX: ColorJitter applied to a PIL image returns a PIL image, and
# torch.stack cannot stack PIL images — the original code stacked first and
# converted second. Convert each jittered image back to a tensor *before*
# stacking, then rescale to uint8 to match the originals.
processed_images = torch.stack([
    transforms.ToTensor()(jitter(transforms.ToPILImage()(img)))
    for img in original_images
])
processed_images = (processed_images * 255).byte()

# Create comparison
comparison_grid = create_comparison_grid(original_images, processed_images)

# Display
comparison_pil = transforms.ToPILImage()(comparison_grid)
comparison_pil.show()

Install with Tessl CLI

npx tessl i tessl/pypi-torchvision

docs

datasets.md

index.md

io.md

models.md

ops.md

transforms.md

tv_tensors.md

utils.md

tile.json