# Utils

The `torchvision.utils` module provides functions for arranging image tensors into grids, saving them to files, and drawing bounding boxes, segmentation masks, keypoints, and optical flow on images. These utilities are useful for debugging, inspecting model outputs, and producing figures from computer vision results.

## Capabilities

### Image Grid and Visualization

Functions for arranging a batch of images into a grid and saving tensor images to files.

```python { .api }
def make_grid(tensor, nrow: int = 8, padding: int = 2, normalize: bool = False, value_range=None, scale_each: bool = False, pad_value: float = 0.0):
    """
    Make a grid of images from a tensor.

    Args:
        tensor (Tensor or list of Tensors): 4D mini-batch tensor of shape (B x C x H x W),
            or a list of images that all have the same size
        nrow (int): Number of images displayed in each row of the grid
        padding (int): Amount of padding between images, in pixels
        normalize (bool): If True, shift the images to the range (0, 1) by subtracting
            the minimum and dividing by (maximum - minimum), using value_range if given
            and the tensor's own pixel values otherwise
        value_range (tuple, optional): Tuple (min, max) used for normalization; values
            outside this range are clamped
        scale_each (bool): If True, normalize each image independently instead of
            using the min/max of the whole batch
        pad_value (float): Value for padding pixels

    Returns:
        Tensor: Image grid tensor of shape (C x H' x W'); single-channel inputs are
            expanded to 3 channels
    """

def save_image(tensor, fp, nrow: int = 8, padding: int = 2, normalize: bool = False, value_range=None, scale_each: bool = False, pad_value: float = 0.0, format=None):
    """
    Save a tensor as an image file. A mini-batch is first arranged into a grid
    with make_grid.

    Recent torchvision releases declare this as save_image(tensor, fp, format=None, **kwargs)
    and forward the grid options to make_grid, so pass them by keyword.

    Args:
        tensor (Tensor or list of Tensors): Image tensor to save. Values are expected to be
            floats in [0, 1]; they are scaled by 255, rounded, and clamped to [0, 255]
            before being written as uint8
        fp (str, pathlib.Path, or file object): File path or file object to write to
        nrow (int): Number of images displayed in each row
        padding (int): Amount of padding between images
        normalize (bool): If True, shift the images to the range (0, 1)
        value_range (tuple, optional): Tuple (min, max) for normalization
        scale_each (bool): If True, scale each image independently
        pad_value (float): Value for padding pixels
        format (str, optional): Image format to use ('PNG', 'JPEG', etc.). If omitted, the
            format is inferred from the file name extension; it should always be given
            when fp is a file object
    """
```
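
The size of the grid that `make_grid` returns follows directly from `nrow` and `padding`. The plain-Python sketch below reproduces that arithmetic. It assumes the behavior of current torchvision releases: single-channel batches are expanded to 3 channels, and a one-image batch comes back without padding. `grid_shape` is a hypothetical helper and is not part of torchvision.

```python
import math


def grid_shape(batch, channels, height, width, nrow=8, padding=2):
    """Predict the (C, H, W) shape of make_grid's output for a (B, C, H, W) batch.

    Hypothetical helper that mirrors make_grid's size arithmetic.
    """
    # Single-channel images are repeated to 3 channels before the grid is built
    out_channels = 3 if channels == 1 else channels
    if batch == 1:
        # A one-image batch is returned as-is (squeezed), without any padding
        return (out_channels, height, width)
    cols = min(nrow, batch)
    rows = math.ceil(batch / cols)
    # Each cell is the image plus `padding`; one extra `padding` closes the outer border
    return (
        out_channels,
        rows * (height + padding) + padding,
        cols * (width + padding) + padding,
    )


print(grid_shape(16, 3, 64, 64, nrow=4))  # (3, 266, 266)
```

For 16 images of 64 x 64 pixels with `nrow=4`, each cell is 66 pixels wide (image plus padding). Four cells take 264 pixels, and one more padding strip closes the border, which gives 266.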

### Bounding Box Visualization

Functions for drawing object detection results.

```python { .api }
def draw_bounding_boxes(image: torch.Tensor, boxes: torch.Tensor, labels=None, colors=None, fill: bool = False, width: int = 1, font=None, font_size: int = 10):
    """
    Draw bounding boxes on an image.

    Args:
        image (Tensor): Image tensor of shape (3, H, W) and dtype uint8
        boxes (Tensor): Bounding boxes of shape (N, 4) in absolute pixel coordinates,
            in (x1, y1, x2, y2) format with x1 <= x2 and y1 <= y2
        labels (list of str, optional): Label drawn next to each bounding box
        colors (color or list of colors, optional): A single color for all boxes or one
            color per box, given as a PIL color string (e.g. 'red' or '#FF00FF') or an
            RGB tuple. If None, colors are generated automatically
        fill (bool): If True, fill the bounding boxes with their color
        width (int): Width of the bounding box lines
        font (str, optional): Path or filename of a TrueType font used for the labels.
            If None, PIL's default font is used
        font_size (int): Font size for the labels; only used when font is given

    Returns:
        Tensor: Image tensor of dtype uint8 with the bounding boxes drawn
    """
```
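
`draw_bounding_boxes` expects corner coordinates. Many detectors instead emit center-based `(cx, cy, w, h)` boxes, so those have to be converted before drawing. torchvision provides `torchvision.ops.box_convert` for this. The plain-Python sketch below only illustrates the arithmetic. `cxcywh_to_xyxy` and `is_valid_xyxy` are hypothetical helpers.

```python
def cxcywh_to_xyxy(boxes):
    """Convert (cx, cy, w, h) boxes to (x1, y1, x2, y2) corner format."""
    converted = []
    for cx, cy, w, h in boxes:
        half_w, half_h = w / 2, h / 2
        converted.append((cx - half_w, cy - half_h, cx + half_w, cy + half_h))
    return converted


def is_valid_xyxy(boxes):
    """Check the ordering that corner-format boxes must satisfy (x1 <= x2, y1 <= y2)."""
    return all(x1 <= x2 and y1 <= y2 for x1, y1, x2, y2 in boxes)


boxes = cxcywh_to_xyxy([(125, 100, 150, 100), (375, 175, 150, 150)])
print(boxes)  # [(50.0, 50.0, 200.0, 150.0), (300.0, 100.0, 450.0, 250.0)]
print(is_valid_xyxy(boxes))  # True
```

Boxes that fail the ordering check usually mean the coordinates are still in another format. Recent torchvision versions reject such boxes with an error.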

### Segmentation Mask Visualization

Functions for overlaying segmentation masks on images.

```python { .api }
def draw_segmentation_masks(image: torch.Tensor, masks: torch.Tensor, alpha: float = 0.8, colors=None):
    """
    Draw segmentation masks on an image.

    Args:
        image (Tensor): Image tensor of shape (3, H, W) and dtype uint8
        masks (Tensor): Boolean mask tensor of shape (N, H, W), where N is the number
            of masks, or (H, W) for a single mask
        alpha (float): Opacity of the masks, from 0.0 (fully transparent) to
            1.0 (fully opaque)
        colors (color or list of colors, optional): A single color for all masks or one
            color per mask (PIL color string or RGB tuple). If None, colors are
            generated automatically

    Returns:
        Tensor: Image tensor of dtype uint8 with the segmentation masks overlaid
    """
```
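
The blending controlled by `alpha` can be sketched in NumPy. This minimal illustration assumes a linear blend: masked pixels become `(1 - alpha) * image + alpha * color`, later masks overwrite earlier ones where they overlap, and unmasked pixels are left unchanged. `overlay_masks` is a hypothetical helper, not the torchvision implementation. It rounds to the nearest integer, so its results can differ from torchvision's by one intensity level.

```python
import numpy as np


def overlay_masks(image, masks, colors, alpha=0.8):
    """Blend boolean (H, W) masks onto a (3, H, W) uint8 image (hypothetical helper)."""
    painted = image.copy()
    for mask, color in zip(masks, colors):
        # Paint every pixel covered by this mask with the mask's RGB color
        painted[:, mask] = np.asarray(color, dtype=np.uint8)[:, None]
    # Linear blend; pixels outside all masks equal the original image in both terms
    blended = image * (1 - alpha) + painted * alpha
    # Round to the nearest integer before casting back to uint8
    return np.rint(blended).astype(np.uint8)


img = np.full((3, 2, 2), 100, dtype=np.uint8)
mask = np.zeros((1, 2, 2), dtype=bool)
mask[0, 0, 0] = True
out = overlay_masks(img, mask, [(255, 0, 0)], alpha=0.5)
print(out[:, 0, 0], out[:, 1, 1])  # [178  50  50] [100 100 100]
```

At `alpha=0.5`, a gray pixel of 100 under a red mask becomes `0.5 * 100 + 0.5 * 255 = 177.5`, which rounds to 178, in the red channel. Pixels outside the mask keep their value.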

### Keypoint Visualization

Functions for drawing keypoints and pose estimation results.

```python { .api }
def draw_keypoints(image: torch.Tensor, keypoints: torch.Tensor, connectivity=None, colors=None, radius: int = 2, width: int = 3):
    """
    Draw keypoints on an image.

    Args:
        image (Tensor): Image tensor of shape (3, H, W) and dtype uint8
        keypoints (Tensor): Keypoints tensor of shape (N, K, 2), where N is the number of
            instances, K is the number of keypoints per instance, and the last dimension
            holds the [x, y] location. Only the x and y columns are used, so model outputs
            of shape (N, K, 3) should be sliced to keypoints[..., :2]
        connectivity (list of tuple[int, int], optional): Pairs of keypoint indices to
            connect with lines
        colors (str or tuple, optional): A single color for the keypoints, given as a PIL
            color string (e.g. 'red' or '#FF00FF') or an RGB tuple
        radius (int): Radius of the keypoint circles
        width (int): Width of the connection lines

    Returns:
        Tensor: Image tensor of dtype uint8 with the keypoints and connections drawn
    """
```
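
Keypoint detectors often return a third visibility column, which `draw_keypoints` does not read. The minimal NumPy sketch below separates the `[x, y]` locations from that column and keeps only the connections whose endpoints are both visible. It assumes any positive flag counts as visible. In COCO annotations, 0 means not labeled, 1 labeled but occluded, and 2 visible. `split_keypoints` is a hypothetical helper.

```python
import numpy as np


def split_keypoints(keypoints, connectivity):
    """Split (N, K, 3) [x, y, visibility] keypoints for drawing (hypothetical helper).

    Returns the (N, K, 2) locations expected by draw_keypoints and, for each
    instance, only the connections whose two endpoints have visibility > 0.
    """
    xy = keypoints[..., :2]
    visible = keypoints[..., 2] > 0  # (N, K) boolean visibility per keypoint
    edges = [
        [(a, b) for a, b in connectivity if vis[a] and vis[b]]
        for vis in visible
    ]
    return xy, edges


kp = np.array([[[10, 10, 1], [20, 20, 0], [30, 30, 2]]], dtype=float)
xy, edges = split_keypoints(kp, [(0, 1), (0, 2), (1, 2)])
print(xy.shape, edges)  # (1, 3, 2) [[(0, 2)]]
```

`draw_keypoints` takes a single `connectivity` list per call, so the per-instance edge lists fit a loop that draws one instance at a time.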

### Optical Flow Visualization

Functions for visualizing optical flow fields.

```python { .api }
def flow_to_image(flow: torch.Tensor):
    """
    Convert optical flow to an RGB image.

    Args:
        flow (Tensor): Optical flow tensor of dtype float and shape (2, H, W) or
            (N, 2, H, W); the first channel is the horizontal flow and the second
            channel is the vertical flow

    Returns:
        Tensor: RGB image tensor of dtype uint8 and shape (3, H, W) or (N, 3, H, W).
            The color (hue) encodes the flow direction and the saturation encodes the
            magnitude relative to the largest flow vector; zero flow is white
    """
```
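
The color coding works like a polar plot: direction becomes hue and relative magnitude becomes saturation. The NumPy sketch below illustrates this with a plain HSV mapping, for intuition only. `flow_to_image` itself uses a color wheel in the style of the Middlebury flow benchmark rather than plain HSV, so the exact colors differ. `flow_to_rgb_sketch` is a hypothetical helper.

```python
import colorsys

import numpy as np


def flow_to_rgb_sketch(flow):
    """Color-code a (2, H, W) flow field: hue = direction, saturation = relative magnitude.

    A simplified HSV approximation (hypothetical helper), not flow_to_image's color wheel.
    """
    u, v = flow[0], flow[1]
    magnitude = np.sqrt(u ** 2 + v ** 2)
    # Normalize by the largest vector so the strongest flow is (almost) fully saturated
    saturation = magnitude / (magnitude.max() + np.finfo(np.float32).eps)
    # Map the angle in (-pi, pi] to a hue in [0, 1)
    hue = (np.arctan2(v, u) / (2 * np.pi)) % 1.0
    height, width = u.shape
    rgb = np.empty((3, height, width), dtype=np.uint8)
    for i in range(height):
        for j in range(width):
            # Full brightness everywhere, so zero flow (saturation 0) maps to white
            r, g, b = colorsys.hsv_to_rgb(hue[i, j], saturation[i, j], 1.0)
            rgb[:, i, j] = [round(r * 255), round(g * 255), round(b * 255)]
    return rgb


flow = np.zeros((2, 2, 2), dtype=np.float32)
flow[0, 0, 0] = 1.0  # rightward motion at a single pixel
img = flow_to_rgb_sketch(flow)
print(img[:, 0, 0], img[:, 1, 1])  # [255   0   0] [255 255 255]
```

Because magnitudes are divided by the maximum, the same absolute motion can look weaker or stronger depending on the fastest vector in the field.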

### Internal Utilities

Internal helper functions used by other TorchVision components. Like other underscore-prefixed names, they are not part of the public API.

```python { .api }
def _Image_fromarray(ndarray, mode=None):
    """
    Internal helper that creates a PIL Image from a NumPy array.

    Args:
        ndarray: NumPy array to convert to a PIL Image
        mode (str, optional): PIL image mode

    Returns:
        PIL Image: Created PIL Image object
    """
```

## Usage Examples

### Creating Image Grids

```python
import torch
import torchvision.utils as utils
import matplotlib.pyplot as plt

# Create a batch of random float images in [0, 1] (simulating model outputs).
# normalize=True needs a floating-point tensor; convert uint8 images with .float() / 255
batch_size, channels, height, width = 16, 3, 64, 64
images = torch.rand(batch_size, channels, height, width)

# Arrange the batch into a 4 x 4 grid, rescaled by the batch min/max
grid = utils.make_grid(images, nrow=4, padding=2, normalize=True)

# Display using matplotlib, which expects channels last (H x W x C)
plt.figure(figsize=(10, 10))
plt.imshow(grid.permute(1, 2, 0))
plt.axis('off')
plt.show()

# Save the same grid to a file
utils.save_image(images, 'output_grid.png', nrow=4, padding=2, normalize=True)
```
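
The effect of `normalize`, `value_range`, and `scale_each` can be mirrored in a few lines of NumPy. This is a minimal sketch that assumes make_grid's clamp-then-rescale behavior: values are clamped to (min, max) and then mapped to [0, 1] via (x - min) / (max - min). `normalize_batch` is a hypothetical helper.

```python
import numpy as np


def normalize_batch(batch, value_range=None, scale_each=False):
    """Mimic the rescaling make_grid applies when normalize=True (hypothetical helper)."""

    def rescale(t):
        # Use the given (low, high) range, or the array's own min and max
        low, high = value_range if value_range is not None else (t.min(), t.max())
        t = np.clip(t, low, high)
        # Guard against division by zero for constant images
        return (t - low) / max(high - low, 1e-5)

    batch = batch.astype(np.float64)
    if scale_each:
        # Each image is stretched using its own min/max
        return np.stack([rescale(img) for img in batch])
    # One min/max shared by the whole batch
    return rescale(batch)


b = np.array([[[[0.0, 2.0]]], [[[4.0, 8.0]]]])  # two 1 x 1 x 2 images
print(normalize_batch(b).ravel())                   # [0.   0.25 0.5  1.  ]
print(normalize_batch(b, scale_each=True).ravel())  # [0. 1. 0. 1.]
```

Without `scale_each`, one image with a wide value range compresses the contrast of every other image in the grid. With `scale_each=True`, each image is stretched to the full range on its own.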

### Visualizing Object Detection Results

```python
import torch
import torchvision.utils as utils
import torchvision.transforms as transforms
from PIL import Image

# Load the image and convert it to a (3, H, W) uint8 tensor
image = Image.open('image.jpg').convert('RGB')
image_uint8 = transforms.PILToTensor()(image)

# Example detection results in (x1, y1, x2, y2) pixel coordinates
boxes = torch.tensor([
    [50, 50, 200, 150],    # First object
    [300, 100, 450, 250],  # Second object
    [100, 300, 250, 400]   # Third object
], dtype=torch.float)

# Labels for detected objects
labels = ['person', 'car', 'dog']

# Colors for bounding boxes (optional)
colors = ['red', 'blue', 'green']

# Draw bounding boxes. To enlarge the labels, pass font_size together with a
# TrueType font file, e.g. font='path/to/font.ttf' (placeholder path);
# without font, font_size is ignored
result = utils.draw_bounding_boxes(
    image_uint8,
    boxes,
    labels=labels,
    colors=colors,
    width=3
)

# Convert back to PIL and display
result_pil = transforms.ToPILImage()(result)
result_pil.show()

# Save result
result_pil.save('detection_result.jpg')
```
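
Real detector outputs usually carry confidence scores. torchvision's detection models return one dict per image with `'boxes'`, `'labels'`, and `'scores'` entries. The plain-Python sketch below thresholds those before drawing. `select_detections`, the 0.5 threshold, and the `class_names` mapping are illustrative assumptions, not torchvision API.

```python
def select_detections(boxes, label_ids, scores, class_names, threshold=0.5):
    """Keep detections scoring at least `threshold` and build caption strings.

    Hypothetical helper; class_names maps integer label ids to readable names.
    """
    kept_boxes, captions = [], []
    for box, label_id, score in zip(boxes, label_ids, scores):
        if score >= threshold:
            kept_boxes.append(box)
            captions.append(f"{class_names[label_id]}: {score:.2f}")
    return kept_boxes, captions


names = {1: 'person', 3: 'car'}
kept, captions = select_detections(
    [[50, 50, 200, 150], [300, 100, 450, 250]], [1, 3], [0.92, 0.31], names
)
print(kept, captions)  # [[50, 50, 200, 150]] ['person: 0.92']
```

With tensors, the kept boxes would be gathered back into an (N, 4) tensor and the captions passed as `labels=` to `draw_bounding_boxes`.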

### Visualizing Segmentation Masks

```python
import torch
import torchvision.utils as utils
from torchvision import transforms

# Create a dummy uint8 image (replace with a real (3, H, W) uint8 image tensor)
image_tensor = torch.randint(0, 256, (3, 300, 300), dtype=torch.uint8)

# Create example boolean segmentation masks with the same H and W as the image
mask1 = torch.zeros(300, 300, dtype=torch.bool)
mask1[50:150, 50:150] = True  # Square mask

mask2 = torch.zeros(300, 300, dtype=torch.bool)
mask2[200:280, 200:280] = True  # Another square mask

masks = torch.stack([mask1, mask2])  # Shape: (2, 300, 300)

# Draw masks on image
result = utils.draw_segmentation_masks(
    image_tensor,
    masks,
    alpha=0.7,
    colors=['red', 'blue']
)

# Display result
result_pil = transforms.ToPILImage()(result)
result_pil.show()
```
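
Semantic segmentation models typically output one score map per class rather than boolean masks. This minimal NumPy sketch turns a `(C, H, W)` score array into the `(N, H, W)` boolean masks that `draw_segmentation_masks` expects, by taking the per-pixel argmax. `logits_to_masks` is a hypothetical helper. The same comparison works in PyTorch with `torch.arange` and `Tensor.argmax`.

```python
import numpy as np


def logits_to_masks(logits):
    """Turn (C, H, W) per-class scores into (C, H, W) boolean masks (hypothetical helper)."""
    num_classes = logits.shape[0]
    predicted = logits.argmax(axis=0)  # (H, W): winning class index per pixel
    # Compare every class index against the per-pixel prediction via broadcasting
    return np.arange(num_classes)[:, None, None] == predicted[None, :, :]


logits = np.array([[[2.0, 0.1]], [[0.5, 3.0]]])  # C=2 classes, H=1, W=2
masks = logits_to_masks(logits)
print(masks.tolist())  # [[[True, False]], [[False, True]]]
```

Because each pixel is assigned exactly one class, the resulting masks never overlap.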

### Visualizing Keypoints

```python
import torch
import torchvision.utils as utils
from torchvision import transforms

# Create example image
image = torch.randint(0, 256, (3, 400, 400), dtype=torch.uint8)

# Example keypoints for one person (17 keypoints in COCO order)
# Shape: (num_people, num_keypoints, 3) where the last dim is [x, y, visibility]
keypoints = torch.tensor([
    [
        [200, 100, 1],  # nose
        [190, 120, 1],  # left eye
        [210, 120, 1],  # right eye
        [180, 130, 1],  # left ear
        [220, 130, 1],  # right ear
        [170, 200, 1],  # left shoulder
        [230, 200, 1],  # right shoulder
        [160, 280, 1],  # left elbow
        [240, 280, 1],  # right elbow
        [150, 350, 1],  # left wrist
        [250, 350, 1],  # right wrist
        [180, 300, 1],  # left hip
        [220, 300, 1],  # right hip
        [175, 360, 1],  # left knee
        [225, 360, 1],  # right knee
        [170, 390, 1],  # left ankle
        [230, 390, 1],  # right ankle
    ]
], dtype=torch.float)

# Define skeleton connections between COCO keypoint indices
connectivity = [
    (0, 1), (0, 2),      # nose to eyes
    (1, 3), (2, 4),      # eyes to ears
    (5, 6),              # shoulders
    (5, 7), (7, 9),      # left arm
    (6, 8), (8, 10),     # right arm
    (5, 11), (6, 12),    # shoulders to hips
    (11, 12),            # hips
    (11, 13), (13, 15),  # left leg
    (12, 14), (14, 16),  # right leg
]

# Draw keypoints; draw_keypoints expects (N, K, 2), so drop the visibility column
result = utils.draw_keypoints(
    image,
    keypoints[..., :2],
    connectivity=connectivity,
    colors='red',  # a single color, not a list
    radius=5,
    width=2
)

# Display result
result_pil = transforms.ToPILImage()(result)
result_pil.show()
```

### Optical Flow Visualization

```python
import torch
import torchvision.utils as utils
from torchvision import transforms
import numpy as np

# Create a synthetic optical flow field
height, width = 256, 256
y, x = np.meshgrid(np.arange(height), np.arange(width), indexing='ij')

# Create a rotational (circular) flow pattern around the image center
center_x, center_y = width // 2, height // 2
dx = -(y - center_y) * 0.1
dy = (x - center_x) * 0.1

# Convert to a (2, H, W) float32 tensor (flow_to_image requires a float tensor)
flow = torch.tensor(np.stack([dx, dy]), dtype=torch.float32)

# Convert flow to an RGB uint8 image
flow_image = utils.flow_to_image(flow)

# Display flow visualization
flow_pil = transforms.ToPILImage()(flow_image)
flow_pil.show()

# Save flow visualization
flow_pil.save('optical_flow.png')
```

### Batch Visualization Pipeline

```python
import torch
import torchvision.utils as utils
import matplotlib.pyplot as plt

def visualize_batch_predictions(images, predictions, labels, num_images=8):
    """
    Visualize a batch of images with predicted and ground-truth labels.

    Args:
        images: Batch of normalized images (B x 3 x H x W)
        predictions: Model outputs (B x num_classes scores or logits)
        labels: Ground-truth class indices (B,)
        num_images: Number of images to visualize
    """
    # Select a subset of images (never more than the batch contains)
    num_images = min(num_images, images.size(0))
    images = images[:num_images]
    predictions = predictions[:num_images]
    labels = labels[:num_images]

    # Denormalize images (assuming ImageNet normalization)
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    images = images * std + mean
    images = torch.clamp(images, 0, 1)

    # Create grid
    grid = utils.make_grid(images, nrow=4, padding=2)

    # Display
    plt.figure(figsize=(12, 8))
    plt.imshow(grid.permute(1, 2, 0))
    plt.axis('off')

    # Add predicted vs. ground-truth classes to the title, four images per line
    pred_classes = torch.argmax(predictions, dim=1)
    title = "Predictions vs Ground Truth\n"
    for i in range(num_images):
        title += f"Img{i+1}: Pred={pred_classes[i].item()}, GT={labels[i].item()} "
        if i % 4 == 3:
            title += "\n"

    plt.title(title)
    plt.tight_layout()
    plt.show()

# Example usage with random data standing in for a real batch
batch_images = torch.randn(16, 3, 224, 224)
batch_predictions = torch.randn(16, 10)  # 10 classes
batch_labels = torch.randint(0, 10, (16,))

visualize_batch_predictions(batch_images, batch_predictions, batch_labels)
```
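
The denormalization step inverts the per-channel standardization applied before the model. The NumPy check below confirms that the two transforms round-trip, using the ImageNet mean and standard deviation from the example above. `normalize` and `denormalize` are illustrative helpers, not torchvision functions.

```python
import numpy as np

# ImageNet per-channel statistics, shaped (3, 1, 1) to broadcast over H and W
MEAN = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
STD = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)


def normalize(img):
    """Forward transform applied before the model: (x - mean) / std per channel."""
    return (img - MEAN) / STD


def denormalize(img):
    """Inverse transform used for display: x * std + mean, clipped to [0, 1]."""
    return np.clip(img * STD + MEAN, 0.0, 1.0)


rng = np.random.default_rng(0)
x = rng.random((3, 4, 4))  # a small image with values in [0, 1)
print(np.allclose(denormalize(normalize(x)), x))  # True
```

With batched `(B, 3, H, W)` arrays, the same `(3, 1, 1)` statistics broadcast over the batch dimension. The clip only matters for values that were outside [0, 1] before normalization, such as the random `torch.randn` inputs in the example.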

### Custom Visualization Functions

```python
import torch
import torchvision.utils as utils
from torchvision import transforms
from torchvision.transforms import ColorJitter

def create_comparison_grid(original, processed, labels=None):
    """
    Create a side-by-side comparison of original and processed images.

    Args:
        original: Batch of original images (B x C x H x W)
        processed: Batch of processed images with the same shape as original
        labels: Optional labels for the images (not used in this example)
    """
    batch_size = original.size(0)

    # Interleave original and processed images: orig0, proc0, orig1, proc1, ...
    # The buffer is float32, so make_grid can normalize it
    comparison = torch.zeros(batch_size * 2, *original.shape[1:])
    comparison[0::2] = original
    comparison[1::2] = processed

    # Create a grid with 2 columns (original, processed)
    grid = utils.make_grid(comparison, nrow=2, padding=2, normalize=True)

    return grid

# Example: before and after augmentation
original_images = torch.randint(0, 256, (4, 3, 128, 128), dtype=torch.uint8)

# Apply color jitter. ColorJitter accepts uint8 image tensors directly, and calling
# it once per image draws different random parameters for each image
jitter = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3)
processed_images = torch.stack([jitter(img) for img in original_images])

# Create comparison
comparison_grid = create_comparison_grid(original_images, processed_images)

# Display
comparison_pil = transforms.ToPILImage()(comparison_grid)
comparison_pil.show()
```
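
The slicing-based interleave in `create_comparison_grid` can also be written by stacking along a new axis and flattening. That form preserves the input dtype. Below is a NumPy sketch. `interleave` is a hypothetical helper, and the PyTorch analogue would be `torch.stack([original, processed], dim=1).flatten(0, 1)`. If the inputs are uint8, convert the result to float before calling `make_grid` with `normalize=True`.

```python
import numpy as np


def interleave(original, processed):
    """Return [orig0, proc0, orig1, proc1, ...] along the first axis (hypothetical helper)."""
    if original.shape != processed.shape:
        raise ValueError("original and processed must have the same shape")
    # Stack into (B, 2, ...) so each original sits next to its processed copy,
    # then merge the first two axes into one of length 2 * B
    pairs = np.stack([original, processed], axis=1)
    return pairs.reshape(-1, *original.shape[1:])


o = np.array([[0], [1]])
p = np.array([[10], [11]])
print(interleave(o, p).ravel())  # [ 0 10  1 11]
```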