Tools for medical image processing with PyTorch
Essential data structures for handling medical images in TorchIO, providing the foundation for medical image processing workflows. These classes handle image loading, metadata management, and provide the base for all TorchIO operations.
A dictionary-like container that stores multiple medical images and metadata for a single patient or case. The Subject class is the primary data structure in TorchIO, organizing related images (T1, T2, segmentations, etc.) and their associated information.
class Subject(dict):
    """
    Dictionary-like container for storing medical images and metadata.

    Parameters:
    - *args: If provided, a dictionary of items
    - **kwargs: Items that will be added to the subject sample
    """

    def __init__(self, *args, **kwargs: dict[str, Any]): ...

    # 'Image' is quoted because the Image class is declared after Subject.
    def get_images(self, intensity_only: bool = True) -> list['Image']:
        """
        Get list of images in the subject.

        Parameters:
        - intensity_only: If True, only return intensity images (not label maps)

        Returns:
            List of Image objects
        """

    def check_consistent_spatial_shape(self):
        """Check all images have the same spatial shape"""

    def check_consistent_orientation(self):
        """Check all images have the same orientation"""

    def check_consistent_space(self):
        """Check all images are in the same physical space"""

    @property
    def shape(self) -> tuple[int, int, int, int]:
        """Return shape of first image in subject (channels, width, height, depth)"""

    @property
    def spatial_shape(self) -> tuple[int, int, int]:
        """Return spatial shape (width, height, depth)"""

    @property
    def spacing(self) -> tuple[float, float, float]:
        """Return voxel spacing in mm"""

    def apply_transform(self, transform: 'Transform') -> 'Subject':
        """Apply a transform to the subject"""

    def plot(self, **kwargs):
        """Plot all images in the subject"""

    def load(self):
        """Load all images in the subject into memory"""


# Usage example:
import torchio as tio

# Create a subject with multiple images and metadata
subject = tio.Subject(
    t1=tio.ScalarImage('path/to/t1.nii.gz'),
    t2=tio.ScalarImage('path/to/t2.nii.gz'),
    flair=tio.ScalarImage('path/to/flair.nii.gz'),
    seg=tio.LabelMap('path/to/segmentation.nii.gz'),
    age=67,
    sex='M',
    diagnosis='healthy',
)

# Access images and metadata
t1_image = subject['t1']
age = subject['age']

# Get all intensity images
intensity_images = subject.get_images(intensity_only=True)

# Check spatial consistency
subject.check_consistent_spatial_shape()

# Base class for all medical images in TorchIO, providing common functionality
# for image loading, data access, and spatial information management.
class Image:
    """
    Base class for medical images.

    Image tensors follow TorchIO's axis convention (C, W, H, D):
    channels first, then the three spatial axes.

    Parameters:
    - path: Path to the image file
    - type: Type of image ('intensity' or 'label')
    - **kwargs: Additional keyword arguments
    """

    # `type` defaults to None, so the annotation must admit None explicitly.
    def __init__(self, path: TypePath, type: 'str | None' = None, **kwargs): ...

    @property
    def data(self) -> torch.Tensor:
        """Image data as PyTorch tensor with shape (C, W, H, D)"""

    @property
    def affine(self) -> np.ndarray:
        """4x4 affine transformation matrix from voxel to world coordinates"""

    @property
    def shape(self) -> tuple[int, int, int, int]:
        """Image shape (channels, width, height, depth)"""

    @property
    def spatial_shape(self) -> tuple[int, int, int]:
        """Spatial shape (width, height, depth)"""

    @property
    def spacing(self) -> tuple[float, float, float]:
        """Voxel spacing in mm along each spatial axis (width, height, depth)"""

    @property
    def origin(self) -> tuple[float, float, float]:
        """Image origin in world coordinates"""

    @property
    def orientation(self) -> tuple[str, str, str]:
        """Image orientation (e.g., ('L', 'A', 'S'))"""

    def get_center(self, lps: bool = False) -> tuple[float, float, float]:
        """Get center coordinates of the image"""

    def save(self, path: TypePath):
        """Save image to file"""

    def plot(self, **kwargs):
        """Plot image using matplotlib"""

    def show(self, **kwargs):
        """Show image using external viewer"""

    def as_pil(self) -> 'PIL.Image':
        """Convert to PIL Image (for 2D images)"""

    def as_sitk(self) -> 'sitk.Image':
        """Convert to SimpleITK Image"""

    def to_gif(self, output_path: TypePath, **kwargs):
        """Save image as animated GIF"""

    def __getitem__(self, item) -> 'Image':
        """Support indexing and slicing"""


# Represents intensity or scalar medical images such as MRI, CT scans,
# PET scans, and other quantitative imaging modalities.
class ScalarImage(Image):
    """
    Represents intensity/scalar medical images (e.g., MRI, CT scans).

    Inherits all functionality from Image class.
    Default type: 'intensity'
    """

    def __init__(self, path: TypePath, **kwargs): ...


# Usage example:
# Load different types of scalar images
t1_image = tio.ScalarImage('t1_weighted.nii.gz')
ct_image = tio.ScalarImage('ct_scan.nii.gz')
pet_image = tio.ScalarImage('pet_scan.nii.gz')

# Access image properties
print(f"T1 shape: {t1_image.shape}")
print(f"T1 spacing: {t1_image.spacing}")
print(f"T1 origin: {t1_image.origin}")

# Access image data
t1_data = t1_image.data  # torch.Tensor with shape (1, W, H, D)

# Represents segmentation or label images where each voxel contains a discrete
# label value indicating tissue type, anatomical structure, or pathological region.
class LabelMap(ScalarImage):
    """
    Represents segmentation/label images.

    Inherits from ScalarImage; intended for images whose voxels hold
    discrete label values.
    Default type: 'label'
    """

    def __init__(self, path: TypePath, **kwargs): ...

    def get_unique_labels(self) -> list[int]:
        """Get list of unique label values in the image"""

    def get_label_statistics(self) -> dict:
        """Get statistics for each label (volume, centroid, etc.)"""


# Usage example:
# Load segmentation image
segmentation = tio.LabelMap('brain_segmentation.nii.gz')

# Get unique labels
labels = segmentation.get_unique_labels()
print(f"Segmentation contains labels: {labels}")

# Access label statistics
stats = segmentation.get_label_statistics()

# PyTorch-compatible dataset class for managing collections of subjects
# with optional transforms.
class SubjectsDataset(torch.utils.data.Dataset):
    """
    PyTorch Dataset for loading multiple subjects with transforms.

    Parameters:
    - subjects: Sequence of Subject instances
    - transform: Optional transform to apply to each subject
    - load_getitem: If True, load images when __getitem__ is called
    """

    # `transform` defaults to None, so the annotation admits None explicitly;
    # 'Transform' is quoted to match the forward-reference style used elsewhere.
    def __init__(
        self,
        subjects: Sequence[Subject],
        transform: 'Transform | None' = None,
        load_getitem: bool = True,
    ): ...

    def __len__(self) -> int:
        """Return number of subjects"""

    def __getitem__(self, index: int) -> Subject:
        """Get subject at index, applying transform if specified"""

    def dry_iter(self):
        """Iterate through subjects without loading image data"""

    def set_transform(self, transform: 'Transform'):
        """Set or update the transform"""


# Usage example:
import torch

# Create subjects
# `ages` is assumed to be a sequence holding one age per subject (100 entries).
subjects = [
    tio.Subject(
        t1=tio.ScalarImage(f'subject_{i}_t1.nii.gz'),
        seg=tio.LabelMap(f'subject_{i}_seg.nii.gz'),
        age=ages[i],
    )
    for i in range(100)
]

# Create dataset with preprocessing transform
preprocessing = tio.Compose([
    tio.ToCanonical(),
    tio.Resample(1),  # 1mm isotropic
    tio.CropOrPad((128, 128, 64)),
    tio.ZNormalization(),
])
dataset = tio.SubjectsDataset(subjects, transform=preprocessing)

# Use with PyTorch DataLoader
loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
for batch in loader:
    # Process batch of subjects
    pass

# Install with Tessl CLI
npx tessl i tessl/pypi-torchio