
# Utils

The `torchvision.utils` module provides functions for arranging image tensors into grids, saving them to disk, and drawing bounding boxes, segmentation masks, keypoints, and optical flow on images. These utilities are particularly useful for debugging, visualizing results, and creating publication-quality figures from computer vision model outputs.

## Capabilities

### Image Grid and Visualization

Functions for creating image grids and saving tensor images to files.

```python { .api }
def make_grid(tensor, nrow: int = 8, padding: int = 2, normalize: bool = False, value_range=None, scale_each: bool = False, pad_value: float = 0.0):
    """
    Make a grid of images from a tensor.

    Args:
        tensor (Tensor or list): 4D mini-batch tensor of shape (B x C x H x W)
            or a list of images all of the same size
        nrow (int): Number of images displayed in each row of the grid
        padding (int): Amount of padding between images, in pixels
        normalize (bool): If True, shift the images to the range (0, 1) by
            subtracting the minimum and dividing by (maximum - minimum); the
            minimum and maximum come from value_range if given, otherwise
            from the tensor itself
        value_range (tuple, optional): Tuple (min, max) used for normalization
        scale_each (bool): If True, normalize each image in the batch
            independently instead of using the (min, max) over all images
        pad_value (float): Value for the padding pixels

    Returns:
        Tensor: Image grid tensor of shape (C x H' x W'); single-channel
            inputs are expanded to 3 channels
    """


def save_image(tensor, fp, nrow: int = 8, padding: int = 2, normalize: bool = False, value_range=None, scale_each: bool = False, pad_value: float = 0.0, format=None):
    """
    Save a tensor as an image file.

    Args:
        tensor (Tensor or list): Image tensor to save; a mini-batch is first
            arranged into a grid with make_grid. Values are expected in the
            range (0, 1) (after optional normalization) and are scaled by 255
        fp (str, pathlib.Path or file object): File path or file object to write to
        nrow (int): Number of images displayed in each row
        padding (int): Amount of padding between images
        normalize (bool): If True, shift the images to the range (0, 1)
        value_range (tuple, optional): Tuple (min, max) used for normalization
        scale_each (bool): If True, normalize each image independently
        pad_value (float): Value for the padding pixels
        format (str, optional): Image format to use ('PNG', 'JPEG', etc.). If
            omitted, the format is inferred from the file extension; it should
            always be given when fp is a file object
    """
```

### Bounding Box Visualization

Functions for drawing and visualizing object detection results.

```python { .api }
def draw_bounding_boxes(image: torch.Tensor, boxes: torch.Tensor, labels=None, colors=None, fill: bool = False, width: int = 1, font=None, font_size: int = 10):
    """
    Draw bounding boxes on an image.

    Args:
        image (Tensor): Image tensor of shape (3, H, W) and dtype uint8
        boxes (Tensor): Bounding boxes of shape (N, 4) in [x1, y1, x2, y2] format,
            given as absolute pixel coordinates (0 <= x1 < x2 < W, 0 <= y1 < y2 < H)
        labels (list of str, optional): One label per bounding box
        colors (list, optional): Colors for the bounding boxes, as PIL color
            strings (e.g. 'red', '#FF00FF') or RGB tuples; random colors are
            generated if omitted
        fill (bool): If True, fill the bounding boxes with their color
        width (int): Width of the bounding box lines
        font (str, optional): Filename of a TrueType font used for the labels;
            PIL's default font is used if omitted
        font_size (int): Font size for the labels; only takes effect when font is set

    Returns:
        Tensor: uint8 image tensor of shape (3, H, W) with the bounding boxes drawn
    """
```

### Segmentation Mask Visualization

Functions for overlaying segmentation masks on images.

```python { .api }
def draw_segmentation_masks(image: torch.Tensor, masks: torch.Tensor, alpha: float = 0.8, colors=None):
    """
    Draw segmentation masks on an image.

    Args:
        image (Tensor): Image tensor of shape (3, H, W) and dtype uint8
        masks (Tensor): Boolean mask tensor of shape (N, H, W), where N is the
            number of masks, or a single mask of shape (H, W)
        alpha (float): Opacity of the masks, from 0.0 (fully transparent) to
            1.0 (fully opaque)
        colors (list, optional): One color per mask, as PIL color strings or
            RGB tuples. If None, colors are generated automatically

    Returns:
        Tensor: Image tensor with the segmentation masks overlaid
    """
```

### Keypoint Visualization

Functions for drawing keypoints and pose estimation results.

```python { .api }
def draw_keypoints(image: torch.Tensor, keypoints: torch.Tensor, connectivity=None, colors=None, radius: int = 2, width: int = 3):
    """
    Draw keypoints on an image.

    Args:
        image (Tensor): Image tensor of shape (3, H, W) and dtype uint8
        keypoints (Tensor): Keypoints tensor of shape (N, K, 2), where N is the number
            of instances, K is the number of keypoints, and the last dim is [x, y]
        connectivity (list, optional): Connections between keypoints, as pairs of
            keypoint indices
        colors (str or tuple, optional): Color for the keypoints, given as a single
            PIL color string (e.g. 'red') or RGB tuple rather than a list
        radius (int): Radius of the keypoint circles
        width (int): Width of the connection lines

    Returns:
        Tensor: Image tensor with the keypoints and connections drawn
    """
```

### Optical Flow Visualization

Functions for visualizing optical flow fields.

```python { .api }
def flow_to_image(flow: torch.Tensor):
    """
    Convert optical flow to an RGB image representation.

    Args:
        flow (Tensor): Floating-point optical flow tensor of shape (2, H, W) or
            (N, 2, H, W); the first channel is horizontal flow and the second
            channel is vertical flow

    Returns:
        Tensor: uint8 RGB image tensor of shape (3, H, W) or (N, 3, H, W) that
            color-codes the flow field (hue for direction, saturation for
            magnitude, normalized by the largest magnitude in the input)
    """
```

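### Internal Utilities

Internal utility functions used by other TorchVision components. Following Python convention, the leading underscore marks them as private: they are not part of the public API and may change between releases.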

```python { .api }
def _Image_fromarray(ndarray, mode=None):
    """
    Internal PIL Image creation function.

    Args:
        ndarray: NumPy array to convert to a PIL Image
        mode (str, optional): PIL image mode

    Returns:
        PIL Image: Created PIL Image object
    """
```

## Usage Examples

### Creating Image Grids

```python
import torch
import torchvision.utils as utils
import matplotlib.pyplot as plt

# Create a batch of random float images (simulating unbounded model outputs).
# normalize=True in make_grid/save_image requires a floating-point tensor.
batch_size, channels, height, width = 16, 3, 64, 64
images = torch.randn(batch_size, channels, height, width)

# Create a 4x4 image grid; normalize=True rescales the values to (0, 1)
grid = utils.make_grid(images, nrow=4, padding=2, normalize=True)

# Display using matplotlib, which expects (H, W, C)
plt.figure(figsize=(10, 10))
plt.imshow(grid.permute(1, 2, 0))
plt.axis('off')
plt.show()

# Save the grid to a file
utils.save_image(images, 'output_grid.png', nrow=4, padding=2, normalize=True)
```

### Visualizing Object Detection Results

```python
import torch
import torchvision.utils as utils
import torchvision.transforms as transforms
from PIL import Image

# Load the image and convert it to a uint8 tensor of shape (3, H, W)
image = Image.open('image.jpg').convert('RGB')
image_uint8 = transforms.PILToTensor()(image)

# Example detection results in absolute pixel coordinates (x1, y1, x2, y2)
boxes = torch.tensor([
    [50, 50, 200, 150],    # First object
    [300, 100, 450, 250],  # Second object
    [100, 300, 250, 400],  # Third object
])

# Labels for the detected objects
labels = ['person', 'car', 'dog']

# Colors for the bounding boxes (optional)
colors = ['red', 'blue', 'green']

# Draw the bounding boxes. To enlarge the labels, pass a TrueType font via
# font=...; font_size alone is ignored with the default font.
result = utils.draw_bounding_boxes(
    image_uint8,
    boxes,
    labels=labels,
    colors=colors,
    width=3,
)

# Convert back to PIL and display
result_pil = transforms.ToPILImage()(result)
result_pil.show()

# Save the result
result_pil.save('detection_result.jpg')
```

### Visualizing Segmentation Masks

```python
import torch
import torchvision.utils as utils
from torchvision import transforms

# Create a dummy uint8 image of shape (3, 300, 300)
image_tensor = torch.randint(0, 256, (3, 300, 300), dtype=torch.uint8)

# Create example boolean segmentation masks
mask1 = torch.zeros(300, 300, dtype=torch.bool)
mask1[50:150, 50:150] = True  # Square mask

mask2 = torch.zeros(300, 300, dtype=torch.bool)
mask2[200:280, 200:280] = True  # Another square mask

masks = torch.stack([mask1, mask2])  # shape (2, 300, 300)

# Draw the masks on the image
result = utils.draw_segmentation_masks(
    image_tensor,
    masks,
    alpha=0.7,
    colors=['red', 'blue'],
)

# Display the result
result_pil = transforms.ToPILImage()(result)
result_pil.show()
```

### Visualizing Keypoints

```python
import torch
import torchvision.utils as utils
from torchvision import transforms

# Create an example image
image = torch.randint(0, 256, (3, 400, 400), dtype=torch.uint8)

# Example keypoints for one person (17 keypoints in COCO order)
# Shape: (num_people, num_keypoints, 3), where the last dim is [x, y, visibility]
keypoints = torch.tensor([
    [
        [200, 100, 1],  # nose
        [190, 120, 1],  # left eye
        [210, 120, 1],  # right eye
        [180, 130, 1],  # left ear
        [220, 130, 1],  # right ear
        [170, 200, 1],  # left shoulder
        [230, 200, 1],  # right shoulder
        [160, 280, 1],  # left elbow
        [240, 280, 1],  # right elbow
        [150, 350, 1],  # left wrist
        [250, 350, 1],  # right wrist
        [180, 300, 1],  # left hip
        [220, 300, 1],  # right hip
        [175, 360, 1],  # left knee
        [225, 360, 1],  # right knee
        [170, 390, 1],  # left ankle
        [230, 390, 1],  # right ankle
    ]
], dtype=torch.float)

# Define skeleton connections (a simplified COCO-style skeleton)
connectivity = [
    (0, 1), (0, 2),      # nose to eyes
    (1, 3), (2, 4),      # eyes to ears
    (5, 6),              # shoulders
    (5, 7), (7, 9),      # left arm
    (6, 8), (8, 10),     # right arm
    (5, 11), (6, 12),    # shoulders to hips
    (11, 12),            # hips
    (11, 13), (13, 15),  # left leg
    (12, 14), (14, 16),  # right leg
]

# Draw the keypoints. draw_keypoints expects [x, y] coordinates, so the
# visibility column is dropped; colors takes a single color, not a list.
result = utils.draw_keypoints(
    image,
    keypoints[..., :2],
    connectivity=connectivity,
    colors='red',
    radius=5,
    width=2,
)

# Display the result
result_pil = transforms.ToPILImage()(result)
result_pil.show()
```

### Optical Flow Visualization

```python
import torch
import torchvision.utils as utils
from torchvision import transforms
import numpy as np

# Create a synthetic optical flow field
height, width = 256, 256
y, x = np.meshgrid(np.arange(height), np.arange(width), indexing='ij')

# Create a circular (rotational) flow pattern around the image center
center_x, center_y = width // 2, height // 2
dx = -(y - center_y) * 0.1
dy = (x - center_x) * 0.1

# Convert to a float tensor of shape (2, H, W)
flow = torch.tensor(np.stack([dx, dy]), dtype=torch.float32)

# Convert the flow to an RGB image
flow_image = utils.flow_to_image(flow)

# Display the flow visualization
flow_pil = transforms.ToPILImage()(flow_image)
flow_pil.show()

# Save the flow visualization
flow_pil.save('optical_flow.png')
```

### Batch Visualization Pipeline

```python
import torch
import torchvision.utils as utils
import matplotlib.pyplot as plt


def visualize_batch_predictions(images, predictions, labels, num_images=8):
    """
    Visualize a batch of images with predictions and ground truth labels.

    Args:
        images: Batch of normalized images, shape (B, 3, H, W)
        predictions: Model outputs (logits), shape (B, num_classes)
        labels: Ground truth labels, shape (B,)
        num_images: Number of images to visualize
    """
    # Select a subset of images
    images = images[:num_images]
    predictions = predictions[:num_images]
    labels = labels[:num_images]

    # Undo normalization (assuming ImageNet mean/std)
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    images = images * std + mean
    images = torch.clamp(images, 0, 1)

    # Create the grid
    grid = utils.make_grid(images, nrow=4, padding=2)

    # Display
    plt.figure(figsize=(12, 8))
    plt.imshow(grid.permute(1, 2, 0))
    plt.axis('off')

    # Add prediction vs. ground truth info
    pred_classes = torch.argmax(predictions, dim=1)
    title = "Predictions vs Ground Truth\n"
    for i in range(len(images)):  # the batch may hold fewer than num_images
        title += f"Img{i+1}: Pred={pred_classes[i].item()}, GT={labels[i].item()} "
        if i % 4 == 3:
            title += "\n"

    plt.title(title)
    plt.tight_layout()
    plt.show()


# Example usage
batch_images = torch.randn(16, 3, 224, 224)
batch_predictions = torch.randn(16, 10)  # 10 classes
batch_labels = torch.randint(0, 10, (16,))

visualize_batch_predictions(batch_images, batch_predictions, batch_labels)
```

### Custom Visualization Functions

```python
import torch
import torchvision.utils as utils
from torchvision import transforms
from torchvision.transforms import ColorJitter


def create_comparison_grid(original, processed, labels=None):
    """
    Create a side-by-side comparison of original and processed images.

    Args:
        original: Batch of original images, shape (B, C, H, W)
        processed: Batch of processed images with the same shape and dtype
        labels: Optional labels for the images (not used in this example)
    """
    # Interleave original and processed images: orig0, proc0, orig1, proc1, ...
    comparison = torch.stack([original, processed], dim=1).flatten(0, 1)

    # Create a grid with 2 columns (original, processed)
    grid = utils.make_grid(comparison, nrow=2, padding=2)

    return grid


# Example: before and after augmentation
original_images = torch.randint(0, 256, (4, 3, 128, 128), dtype=torch.uint8)

# Apply some processing (e.g., color jitter). ColorJitter accepts uint8 image
# tensors directly, so no PIL round trip is needed.
jitter = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3)
processed_images = torch.stack([jitter(img) for img in original_images])

# Create the comparison grid
comparison_grid = create_comparison_grid(original_images, processed_images)

# Display
comparison_pil = transforms.ToPILImage()(comparison_grid)
comparison_pil.show()
```