0
# TV Tensors
1
2
TorchVision TV Tensors provide enhanced tensor types that preserve metadata and semantics through transformations. These specialized tensors enable transforms to handle multiple data types (images, videos, bounding boxes, masks, keypoints) consistently while maintaining their specific properties and constraints.
3
4
## Capabilities
5
6
### Base TV Tensor
7
8
Foundation class for all TorchVision tensor types with enhanced metadata support.
9
10
```python { .api }
11
class TVTensor(torch.Tensor):
12
"""
13
Base class for all torchvision tensor types.
14
Extends torch.Tensor with metadata preservation through transformations.
15
16
Provides automatic wrapping and unwrapping of tensor operations while
17
maintaining type-specific metadata and constraints.
18
"""
19
20
def __new__(cls, data, **kwargs): ...
21
22
def wrap_like(self, other, **kwargs):
23
"""Wrap tensor with same type and metadata as another TVTensor."""
24
```
25
26
### Image Tensors
27
28
Enhanced image tensors that preserve image semantics and properties.
29
30
```python { .api }
31
class Image(TVTensor):
32
"""
33
Image tensor type with preserved image semantics.
34
35
Inherits from TVTensor (a torch.Tensor subclass) and maintains image-specific properties
36
through transformations. Ensures operations maintain image constraints
37
like channel ordering and value ranges.
38
39
Args:
40
data: Image data as tensor, PIL Image, or numpy array
41
dtype: Data type for the tensor (default: inferred)
42
device: Device to place tensor on
43
requires_grad: Whether tensor requires gradients
44
45
Shape:
46
- (C, H, W): Single image with C channels, H height, W width
47
- (N, C, H, W): Batch of N images
48
"""
49
50
def __new__(cls, data, *, dtype=None, device=None, requires_grad=None): ...
51
52
@property
53
def spatial_size(self) -> tuple:
54
"""Get spatial dimensions (height, width)."""
55
56
@property
57
def num_channels(self) -> int:
58
"""Get number of channels."""
59
60
@property
61
def image_size(self) -> tuple:
62
"""Get image size as (height, width)."""
63
```
64
65
### Video Tensors
66
67
Specialized tensors for temporal video data with frame sequence handling.
68
69
```python { .api }
70
class Video(TVTensor):
71
"""
72
Video tensor type for temporal data sequences.
73
74
Handles temporal dimension and preserves video-specific properties
75
through transformations. Maintains frame relationships and temporal
76
consistency.
77
78
Args:
79
data: Video data as tensor or array
80
dtype: Data type for the tensor
81
device: Device to place tensor on
82
requires_grad: Whether tensor requires gradients
83
84
Shape:
85
- (T, C, H, W): Single video with T frames, C channels, H height, W width
86
- (N, T, C, H, W): Batch of N videos
87
"""
88
89
def __new__(cls, data, *, dtype=None, device=None, requires_grad=None): ...
90
91
@property
92
def num_frames(self) -> int:
93
"""Get number of frames."""
94
95
@property
96
def frame_size(self) -> tuple:
97
"""Get frame dimensions (height, width)."""
98
99
@property
100
def temporal_size(self) -> int:
101
"""Get temporal dimension size."""
102
```
103
104
### Bounding Box Tensors
105
106
Bounding box tensors with format awareness and coordinate system handling.
107
108
```python { .api }
109
class BoundingBoxes(TVTensor):
110
"""
111
Bounding box tensor with format and canvas size metadata.
112
113
Handles different bounding box formats and maintains coordinate
114
system constraints. Automatically handles transformations while
115
preserving box validity and format consistency.
116
117
Args:
118
data: Bounding box coordinates as tensor
119
format: Box format ('XYXY', 'XYWH', 'CXCYWH')
120
canvas_size: Image dimensions as (height, width)
121
dtype: Data type for coordinates
122
device: Device to place tensor on
123
requires_grad: Whether tensor requires gradients
124
125
Shape:
126
- (4,): Single bounding box, e.g. [x1, y1, x2, y2] for XYXY (layout depends on format)
127
- (N, 4): N bounding boxes
128
"""
129
130
def __new__(cls, data, *, format: str, canvas_size: tuple, dtype=None, device=None, requires_grad=None): ...
131
132
@property
133
def format(self) -> str:
134
"""Get bounding box format ('XYXY', 'XYWH', 'CXCYWH')."""
135
136
@property
137
def canvas_size(self) -> tuple:
138
"""Get canvas dimensions (height, width)."""
139
140
@property
141
def clamping_mode(self) -> str:
142
"""Get clamping mode for out-of-bounds boxes."""
143
144
def clamp(self) -> 'BoundingBoxes':
145
"""Clamp boxes to canvas boundaries."""
146
147
def convert_format(self, format: str) -> 'BoundingBoxes':
148
"""Convert to different bounding box format."""
149
150
class BoundingBoxFormat:
151
"""Bounding box format constants and utilities."""
152
153
XYXY: str = "XYXY" # [x_min, y_min, x_max, y_max]
154
XYWH: str = "XYWH" # [x_min, y_min, width, height]
155
CXCYWH: str = "CXCYWH" # [center_x, center_y, width, height]
156
157
def is_rotated_bounding_format(format: str) -> bool:
158
"""
159
Check if bounding box format supports rotated boxes.
160
161
Args:
162
format (str): Bounding box format string
163
164
Returns:
165
bool: True if format supports rotation
166
"""
167
```
168
169
### Mask Tensors
170
171
Segmentation mask tensors for pixel-level annotations and predictions.
172
173
```python { .api }
174
class Mask(TVTensor):
175
"""
176
Segmentation mask tensor type for pixel-level annotations.
177
178
Handles boolean or integer masks while preserving spatial
179
relationships and mask-specific properties through transformations.
180
181
Args:
182
data: Mask data as tensor or array (boolean or integer)
183
dtype: Data type (typically torch.bool or torch.uint8)
184
device: Device to place tensor on
185
requires_grad: Whether tensor requires gradients
186
187
Shape:
188
- (H, W): Single binary mask
189
- (N, H, W): N masks (e.g., instance segmentation)
190
- (C, H, W): Multi-class segmentation mask
191
"""
192
193
def __new__(cls, data, *, dtype=None, device=None, requires_grad=None): ...
194
195
@property
196
def spatial_size(self) -> tuple:
197
"""Get spatial dimensions (height, width)."""
198
199
@property
200
def num_masks(self) -> int:
201
"""Get number of individual masks."""
202
203
def to_binary(self) -> 'Mask':
204
"""Convert to binary mask format."""
205
```
206
207
### Keypoint Tensors
208
209
Keypoint tensors for pose estimation and landmark detection tasks.
210
211
```python { .api }
212
class KeyPoints(TVTensor):
213
"""
214
Keypoint tensor with canvas size metadata and optional visibility information.
215
216
Handles keypoint coordinates with visibility information and
217
maintains spatial constraints through transformations.
218
219
Args:
220
data: Keypoint coordinates and visibility
221
canvas_size: Image dimensions as (height, width)
222
dtype: Data type for coordinates
223
device: Device to place tensor on
224
requires_grad: Whether tensor requires gradients
225
226
Shape:
227
- (K, 2): K keypoints with [x, y] coordinates
228
- (K, 3): K keypoints with [x, y, visibility]
229
- (N, K, 2): N instances with K keypoints each
230
- (N, K, 3): N instances with K keypoints and visibility
231
"""
232
233
def __new__(cls, data, *, canvas_size: tuple, dtype=None, device=None, requires_grad=None): ...
234
235
@property
236
def canvas_size(self) -> tuple:
237
"""Get canvas dimensions (height, width)."""
238
239
@property
240
def num_keypoints(self) -> int:
241
"""Get number of keypoints per instance."""
242
243
@property
244
def num_instances(self) -> int:
245
"""Get number of instances."""
246
247
def has_visibility(self) -> bool:
248
"""Check if keypoints have visibility information."""
249
250
def clamp(self) -> 'KeyPoints':
251
"""Clamp keypoints to canvas boundaries."""
252
```
253
254
### Tensor Wrapping and Utilities
255
256
Utilities for working with TV tensors and managing type consistency.
257
258
```python { .api }
259
def wrap(wrappee, *, like, **kwargs):
260
"""
261
Wrap tensor as same type as reference tensor.
262
263
Args:
264
wrappee: Tensor to wrap
265
like: Reference TV tensor to match type and metadata
266
**kwargs: Additional arguments for tensor creation
267
268
Returns:
269
TV tensor of same type as 'like' parameter
270
"""
271
272
def set_return_type(return_type: str):
273
"""
274
Set return type for tensor operations.
275
276
Args:
277
return_type (str): Type to return from operations
278
('TVTensor' or 'Tensor', case-insensitive; the default is 'Tensor')
279
"""
280
```
281
282
## Usage Examples
283
284
### Working with Image Tensors
285
286
```python
287
from torchvision.tv_tensors import Image
288
import torch
289
from PIL import Image as PILImage
290
291
# Create Image tensor from different sources
292
pil_image = PILImage.open('image.jpg')
293
image_tensor = Image(pil_image)
294
print(f"Image shape: {image_tensor.shape}")
295
print(f"Image size: {image_tensor.image_size}")
296
print(f"Channels: {image_tensor.num_channels}")
297
298
# Create from tensor data
299
tensor_data = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
300
image_tensor = Image(tensor_data)
301
302
# By default, operations on an Image return a plain torch.Tensor (see set_return_type)
303
scaled_image = image_tensor * 0.5
304
print(f"Scaled image type: {type(scaled_image)}") # torch.Tensor by default; Image only with set_return_type("TVTensor")
305
306
# Batch of images
307
batch_data = torch.randint(0, 256, (8, 3, 224, 224), dtype=torch.uint8)
308
batch_images = Image(batch_data)
309
print(f"Batch shape: {batch_images.shape}")
310
```
311
312
### Working with Bounding Box Tensors
313
314
```python
315
from torchvision.tv_tensors import BoundingBoxes, BoundingBoxFormat
316
import torch
317
318
# Create bounding boxes in XYXY format
319
boxes_data = torch.tensor([
320
[10, 20, 100, 150],
321
[50, 30, 200, 180],
322
[75, 60, 150, 200]
323
], dtype=torch.float)
324
325
canvas_size = (240, 320) # (height, width)
326
boxes = BoundingBoxes(
327
boxes_data,
328
format=BoundingBoxFormat.XYXY,
329
canvas_size=canvas_size
330
)
331
332
print(f"Boxes format: {boxes.format}")
333
print(f"Canvas size: {boxes.canvas_size}")
334
335
# Convert between formats
336
boxes_xywh = boxes.convert_format(BoundingBoxFormat.XYWH)
337
print(f"Converted format: {boxes_xywh.format}")
338
print(f"XYWH boxes: {boxes_xywh}")
339
340
# Clamp boxes to canvas boundaries
341
# (useful after transformations that might move boxes out of bounds)
342
clamped_boxes = boxes.clamp()
343
344
# Create boxes in different format
345
center_boxes = BoundingBoxes(
346
torch.tensor([[50, 60, 80, 100]]), # [cx, cy, w, h]
347
format=BoundingBoxFormat.CXCYWH,
348
canvas_size=canvas_size
349
)
350
```
351
352
### Working with Mask Tensors
353
354
```python
355
from torchvision.tv_tensors import Mask
356
import torch
357
358
# Binary segmentation mask
359
binary_mask_data = torch.zeros(240, 320, dtype=torch.bool)
360
binary_mask_data[50:150, 60:160] = True # Create rectangular mask
361
binary_mask = Mask(binary_mask_data)
362
363
print(f"Binary mask shape: {binary_mask.shape}")
364
print(f"Spatial size: {binary_mask.spatial_size}")
365
366
# Multi-instance masks (e.g., for instance segmentation)
367
num_instances = 3
368
instance_masks_data = torch.zeros(num_instances, 240, 320, dtype=torch.bool)
369
instance_masks_data[0, 20:80, 30:90] = True # Instance 1
370
instance_masks_data[1, 100:160, 150:210] = True # Instance 2
371
instance_masks_data[2, 180:220, 50:150] = True # Instance 3
372
373
instance_masks = Mask(instance_masks_data)
374
print(f"Instance masks shape: {instance_masks.shape}")
375
print(f"Number of masks: {instance_masks.num_masks}")
376
377
# Integer-valued masks (e.g., semantic segmentation)
378
semantic_mask_data = torch.zeros(240, 320, dtype=torch.uint8)
379
semantic_mask_data[50:100, 50:100] = 1 # Class 1
380
semantic_mask_data[150:200, 150:200] = 2 # Class 2
381
382
semantic_mask = Mask(semantic_mask_data)
383
```
384
385
### Working with Keypoint Tensors
386
387
```python
388
from torchvision.tv_tensors import KeyPoints
389
import torch
390
391
# COCO-style human pose keypoints (17 keypoints)
392
# Format: [x, y, visibility] where visibility: 0=not labeled, 1=labeled but not visible, 2=labeled and visible
393
keypoints_data = torch.tensor([
394
[160, 80, 2], # nose
395
[155, 85, 2], # left_eye
396
[165, 85, 2], # right_eye
397
[150, 90, 2], # left_ear
398
[170, 90, 2], # right_ear
399
[140, 120, 2], # left_shoulder
400
[180, 120, 2], # right_shoulder
401
[130, 150, 2], # left_elbow
402
[190, 150, 1], # right_elbow (labeled but occluded)
403
[125, 180, 0], # left_wrist (not labeled)
404
[195, 180, 2], # right_wrist
405
[150, 200, 2], # left_hip
406
[170, 200, 2], # right_hip
407
[145, 240, 2], # left_knee
408
[175, 240, 2], # right_knee
409
[140, 280, 2], # left_ankle
410
[180, 280, 2], # right_ankle
411
], dtype=torch.float)
412
413
canvas_size = (320, 240)
414
keypoints = KeyPoints(keypoints_data, canvas_size=canvas_size)
415
416
print(f"Keypoints shape: {keypoints.shape}")
417
print(f"Number of keypoints: {keypoints.num_keypoints}")
418
print(f"Has visibility: {keypoints.has_visibility()}")
419
print(f"Canvas size: {keypoints.canvas_size}")
420
421
# Multiple person keypoints
422
batch_keypoints_data = torch.randn(5, 17, 3) # 5 people, 17 keypoints, [x,y,vis]
423
batch_keypoints = KeyPoints(batch_keypoints_data, canvas_size=canvas_size)
424
print(f"Batch keypoints shape: {batch_keypoints.shape}")
425
print(f"Number of instances: {batch_keypoints.num_instances}")
426
427
# Clamp keypoints to image boundaries
428
clamped_keypoints = keypoints.clamp()
429
```
430
431
### Working with Video Tensors
432
433
```python
434
from torchvision.tv_tensors import Video
435
import torch
436
437
# Create video tensor (16 frames, 3 channels, 224x224)
438
video_data = torch.randint(0, 256, (16, 3, 224, 224), dtype=torch.uint8)
439
video = Video(video_data)
440
441
print(f"Video shape: {video.shape}")
442
print(f"Number of frames: {video.num_frames}")
443
print(f"Frame size: {video.frame_size}")
444
print(f"Temporal size: {video.temporal_size}")
445
446
# Batch of videos
447
batch_video_data = torch.randint(0, 256, (4, 16, 3, 224, 224), dtype=torch.uint8)
448
batch_videos = Video(batch_video_data)
449
print(f"Batch videos shape: {batch_videos.shape}")
450
451
# Slicing returns a plain torch.Tensor by default; Video is kept only with set_return_type("TVTensor")
452
downsampled_video = video[:8] # Take first 8 frames
453
print(f"Downsampled video type: {type(downsampled_video)}")
454
```
455
456
### Using TV Tensors with Transforms
457
458
```python
459
from torchvision.tv_tensors import Image, BoundingBoxes, BoundingBoxFormat
460
from torchvision.transforms import v2
461
import torch
462
463
# Create sample data
464
image = Image(torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8))
465
boxes = BoundingBoxes(
466
torch.tensor([[50, 50, 200, 200], [300, 150, 450, 350]]),
467
format=BoundingBoxFormat.XYXY,
468
canvas_size=(480, 640)
469
)
470
471
print("Before transform:")
472
print(f"Image shape: {image.shape}")
473
print(f"Boxes: {boxes}")
474
print(f"Boxes format: {boxes.format}")
475
476
# Apply transforms that work with multiple tensor types
477
transform = v2.Compose([
478
v2.RandomHorizontalFlip(p=1.0), # Always flip for demonstration
479
v2.Resize((224, 224)),
480
v2.ToDtype(torch.float32, scale=True)
481
])
482
483
# Transform both image and boxes together
484
transformed_image, transformed_boxes = transform(image, boxes)
485
486
print("\nAfter transform:")
487
print(f"Image shape: {transformed_image.shape}")
488
print(f"Image type: {type(transformed_image)}")
489
print(f"Boxes: {transformed_boxes}")
490
print(f"Boxes type: {type(transformed_boxes)}")
491
print(f"Boxes format: {transformed_boxes.format}")
492
print(f"New canvas size: {transformed_boxes.canvas_size}")
493
```
494
495
### Custom Operations with TV Tensors
496
497
```python
498
from torchvision.tv_tensors import Image, wrap
499
import torch
500
501
def custom_image_operation(img):
502
"""
503
Custom operation that preserves TV tensor type.
504
"""
505
# Perform some operation on the underlying tensor
506
processed = img * 0.8 + 0.1 # Adjust brightness
507
508
# Wrap result to maintain TV tensor type and metadata
509
return wrap(processed, like=img)
510
511
def batch_process_images(images):
512
"""
513
Process a batch of images one by one and stack the results into a single batched tensor.
514
"""
515
results = []
516
for img in images:
517
processed = custom_image_operation(img)
518
results.append(processed)
519
520
return torch.stack(results)
521
522
# Test custom operations
523
image = Image(torch.rand(3, 224, 224))
524
processed_image = custom_image_operation(image)
525
526
print(f"Original type: {type(image)}")
527
print(f"Processed type: {type(processed_image)}")
528
529
# Works with batches too
530
batch_images = [Image(torch.rand(3, 224, 224)) for _ in range(4)]
531
batch_result = batch_process_images(batch_images)
532
print(f"Batch result shape: {batch_result.shape}")
533
print(f"Batch result type: {type(batch_result)}")
534
```
535
536
### Type Consistency in Pipelines
537
538
```python
539
from torchvision.tv_tensors import Image, BoundingBoxes, Mask, BoundingBoxFormat
540
from torchvision.transforms import v2
541
import torch
542
543
def detection_pipeline():
544
"""
545
Example object detection data pipeline using TV tensors.
546
"""
547
# Simulate loading detection data
548
image = Image(torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8))
549
550
boxes = BoundingBoxes(
551
torch.tensor([[100, 100, 300, 250], [200, 150, 400, 350]]),
552
format=BoundingBoxFormat.XYXY,
553
canvas_size=(480, 640)
554
)
555
556
# Instance masks for each detection
557
masks_data = torch.zeros(2, 480, 640, dtype=torch.bool)
558
masks_data[0, 100:250, 100:300] = True
559
masks_data[1, 150:350, 200:400] = True
560
masks = Mask(masks_data)
561
562
# Labels for each detection
563
labels = torch.tensor([1, 2]) # Class IDs
564
565
print("Original data:")
566
print(f"Image: {image.shape}, {type(image)}")
567
print(f"Boxes: {boxes.shape}, {type(boxes)}")
568
print(f"Masks: {masks.shape}, {type(masks)}")
569
570
# Data augmentation pipeline
571
transform = v2.Compose([
572
v2.RandomHorizontalFlip(p=0.5),
573
v2.RandomResizedCrop((416, 416), scale=(0.8, 1.0)),
574
v2.ColorJitter(brightness=0.2, contrast=0.2),
575
v2.ToDtype(torch.float32, scale=True),
576
v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
577
])
578
579
# Apply transforms (all types are handled automatically)
580
aug_image, aug_boxes, aug_masks = transform(image, boxes, masks)
581
582
print("\nAfter augmentation:")
583
print(f"Image: {aug_image.shape}, {type(aug_image)}")
584
print(f"Boxes: {aug_boxes.shape}, {type(aug_boxes)}")
585
print(f"Masks: {aug_masks.shape}, {type(aug_masks)}")
586
print(f"Canvas size updated: {aug_boxes.canvas_size}")
587
588
return aug_image, aug_boxes, aug_masks, labels
589
590
# Run detection pipeline
591
processed_data = detection_pipeline()
592
```