Tools for medical image processing with PyTorch
npx @tessl/cli install tessl/pypi-torchio@0.20.0

TorchIO is a comprehensive Python package designed for efficient processing of 3D medical images in deep learning applications built with PyTorch. It provides a complete toolkit for reading, preprocessing, sampling, augmenting, and writing medical images, featuring both standard computer vision transforms and domain-specific medical imaging transforms that simulate realistic artifacts such as MRI magnetic field inhomogeneity and k-space motion artifacts.
Install the package:

pip install torchio

Import it as:

import torchio as tio

For specific components:
from torchio import Subject, ScalarImage, LabelMap
from torchio import SubjectsDataset, SubjectsLoader
from torchio import Compose, RandomFlip, RandomAffine
from torchio.data import UniformSampler, GridSampler

Quick start:

import torchio as tio
# Create a subject with medical images (plus arbitrary metadata)
subject = tio.Subject(
    t1=tio.ScalarImage('t1.nii.gz'),
    t2=tio.ScalarImage('t2.nii.gz'),
    seg=tio.LabelMap('seg.nii.gz'),
    age=45,
    name='Subject_001',
)

# Define preprocessing and augmentation transforms
transform = tio.Compose([
    tio.ToCanonical(),              # Reorient to canonical orientation
    tio.Resample(1),                # Resample to 1 mm isotropic spacing
    tio.CropOrPad((128, 128, 64)),  # Crop or pad to target shape
    tio.ZNormalization(),           # Z-score normalization
    tio.RandomFlip(),               # Random flipping
    tio.RandomAffine(),             # Random affine transformation
    tio.RandomNoise(std=0.1),       # Add random noise
])

# Apply transforms to subject
transformed_subject = transform(subject)

# Create dataset for training.
# This is a list of Subject instances; extend it with your own subjects,
# e.g. [subject1, subject2, subject3].
subjects = [subject]
dataset = tio.SubjectsDataset(subjects, transform=transform)

# Create a patch queue for patch-based training
patch_size = 64
samples_per_volume = 10
sampler = tio.data.UniformSampler(patch_size)
patches_queue = tio.Queue(
    subjects_dataset=dataset,
    max_length=100,
    samples_per_volume=samples_per_volume,
    sampler=sampler,
)

# TorchIO 0.20 provides SubjectsLoader for batching Subject objects; per the
# TorchIO Queue documentation, num_workers should be 0 here because the Queue
# loads subjects with its own worker processes.
loader = tio.SubjectsLoader(patches_queue, batch_size=4, num_workers=0)

# Training loop
for batch in loader:
    # batch contains patches ready for training
    inputs = batch['t1'][tio.DATA]
    targets = batch['seg'][tio.DATA]
    # ... train your model

TorchIO follows a hierarchical design optimized for medical image processing workflows:
Essential data structures for handling medical images, including the Subject container, various Image types, and dataset management for organizing multiple subjects.
from typing import Any, Optional


class Subject(dict):
    """Dictionary-like container for the images and metadata of one subject.

    In the quick-start example, images (``t1=ScalarImage(...)``) and plain
    metadata (``age=45``, ``name='Subject_001'``) are passed as keyword
    arguments.
    """

    # Keyword values may be images or arbitrary metadata, so each value is
    # annotated as ``Any`` (``dict[str, Any]`` would claim every value is a dict).
    def __init__(self, *args, **kwargs: Any): ...

    # "Image" is quoted because the Image class is defined below this one.
    def get_images(self, intensity_only: bool = True) -> list["Image"]:
        """Return the images stored in this subject.

        NOTE(review): ``intensity_only=True`` presumably restricts the result
        to intensity (non-label) images; confirm against the TorchIO docs.
        """
        ...

    def check_consistent_spatial_shape(self):
        """Validate the spatial shapes of the subject's images.

        NOTE(review): presumably raises when the images do not share a single
        spatial shape; confirm against the TorchIO docs.
        """
        ...


class Image:
    """Base class for a medical image, typically read from ``path``."""

    # ``TypePath`` is quoted because it is declared later in this document.
    # NOTE(review): upstream TorchIO can also build images from an in-memory
    # tensor, in which case ``path`` is omitted, hence the ``None`` default.
    # Confirm the full signature for the installed version.
    def __init__(
        self,
        path: Optional["TypePath"] = None,
        type: Optional[str] = None,
        **kwargs,
    ): ...

    @property
    def data(self) -> torch.Tensor:
        """Image data as a ``torch.Tensor``."""
        ...

    @property
    def affine(self) -> np.ndarray:
        """Affine matrix of the image as a NumPy array."""
        ...


class ScalarImage(Image):
    """Intensity/scalar medical image (e.g., MRI or CT scans)."""


class LabelMap(ScalarImage):
    """Segmentation/label image."""
class SubjectsDataset(torch.utils.data.Dataset):
    def __init__(self, subjects: Sequence[Subject], transform: Transform = None): ...

PyTorch-compatible data loading utilities optimized for medical imaging, including specialized data loaders, queues for patch-based training, and efficient batch processing.
class SubjectsLoader(torch.utils.data.DataLoader):
    """Batch loader for TorchIO subjects.

    A ``torch.utils.data.DataLoader`` subclass tailored to datasets whose
    items are medical image subjects.
    """
class Queue(torch.utils.data.Dataset):
    def __init__(
        self,
        subjects_dataset: SubjectsDataset,
        max_length: int,
        samples_per_volume: int,
        sampler: PatchSampler,
        **kwargs
    ): ...

Flexible sampling strategies for extracting patches from 3D medical volumes, supporting uniform sampling, weighted sampling based on probability maps, label-focused sampling, and grid-based sampling for inference.
from typing import Optional


# Forward references ("Subject", "TypeSpatialShape") are quoted because those
# names are declared elsewhere in this document.
class PatchSampler:
    """Base class for strategies that extract patches from a subject."""

    # NOTE(review): upstream TorchIO samplers yield patch subjects rather than
    # returning a single dict; confirm the return type for the installed version.
    def __call__(self, sample: "Subject") -> dict: ...


class UniformSampler(PatchSampler):
    """Uniform patch sampling over the volume."""

    def __init__(self, patch_size: "TypeSpatialShape"): ...


class WeightedSampler(PatchSampler):
    """Patch sampling driven by a probability map.

    NOTE(review): ``probability_map`` is presumably the name of the image in
    the subject that holds the sampling probabilities; confirm.
    """

    def __init__(self, patch_size: "TypeSpatialShape", probability_map: str): ...


class LabelSampler(WeightedSampler):
    """Label-focused patch sampling based on the label map ``label_name``."""

    def __init__(
        self,
        patch_size: "TypeSpatialShape",
        label_name: str,
        label_probabilities: Optional[dict] = None,
    ): ...


class GridSampler(PatchSampler):
    """Extract patches on a regular grid, e.g. for patch-based inference.

    NOTE(review): upstream TorchIO's ``GridSampler`` accepts the subject as its
    first positional argument (``GridSampler(subject, patch_size, patch_overlap)``).
    Pass ``patch_size``/``patch_overlap`` by keyword and confirm against the
    installed version.
    """

    def __init__(
        self,
        patch_size: "TypeSpatialShape",
        patch_overlap: "TypeSpatialShape" = 0,
    ): ...
class GridAggregator:
    def __init__(self, sampler: GridSampler, overlap_mode: str = 'crop'): ...

Comprehensive preprocessing transforms for medical images, including spatial transformations (resampling, cropping, padding), intensity normalization (z-score, histogram standardization), and specialized medical imaging preprocessing.
from typing import Optional


# Type aliases are quoted because they are declared later in this document.

# Spatial preprocessing
class Resample(SpatialTransform):
    """Resample images to the spacing ``target`` (e.g. ``1`` for 1 mm isotropic)."""

    def __init__(self, target: "TypeSpacing", image_interpolation: str = 'linear'): ...


class CropOrPad(SpatialTransform):
    """Crop and/or pad images to ``target_shape``.

    NOTE(review): upstream TorchIO documents ``padding_mode`` with a default of
    ``0`` (constant padding with zeros) rather than the string ``'constant'``;
    confirm against the installed version.
    """

    def __init__(self, target_shape: "TypeSpatialShape", padding_mode: str = 'constant'): ...


class ToCanonical(SpatialTransform):
    """Reorients images to canonical orientation (RAS+)"""


# Intensity preprocessing
class ZNormalization(IntensityTransform):
    """Z-score intensity normalization.

    ``masking_method`` is optional; ``None`` (the default) means no masking
    method is given. NOTE(review): accepted values are not shown here; confirm.
    """

    def __init__(self, masking_method: Optional[str] = None): ...


class RescaleIntensity(IntensityTransform):
    """Rescale intensities into the ``out_min_max`` range (default ``(0, 1)``)."""

    def __init__(self, out_min_max: tuple[float, float] = (0, 1)): ...
class HistogramStandardization(IntensityTransform):
    def __init__(self, landmarks: dict): ...

Extensive augmentation transforms, including spatial augmentations (affine, elastic deformation, flipping) and intensity augmentations that simulate medical imaging artifacts (motion, ghosting, bias field, spike artifacts).
from typing import Optional


# Type aliases are quoted because they are declared later in this document.

# Spatial augmentation
class RandomFlip(SpatialTransform):
    """Random flipping along ``axes``, controlled by ``flip_probability``.

    NOTE(review): ``(0,)`` is a 1-tuple, which does not match ``TypeTuple``
    (an int or an int triplet); upstream TorchIO uses ``axes=0``. Confirm.
    """

    def __init__(self, axes: "TypeTuple" = (0,), flip_probability: float = 0.5): ...


class RandomAffine(SpatialTransform):
    """Random affine transformation (scaling, rotation, translation).

    NOTE(review): upstream TorchIO appears to use numeric defaults
    (``scales=0.1``, ``degrees=10``, ``translation=0``) rather than ``None``;
    confirm for the installed version.
    """

    def __init__(
        self,
        scales: Optional["TypeRangeFloat"] = None,
        degrees: Optional["TypeRangeFloat"] = None,
        translation: Optional["TypeRangeFloat"] = None,
        **kwargs
    ): ...


class RandomElasticDeformation(SpatialTransform):
    """Random elastic deformation parameterized by control points."""

    def __init__(
        self,
        num_control_points: "TypeTuple" = 7,
        max_displacement: "TypeRangeFloat" = 7.5,
    ): ...


# Medical imaging specific augmentation
class RandomMotion(IntensityTransform):
    """Simulate MRI (k-space) motion artifacts."""

    def __init__(self, degrees: "TypeRangeFloat" = 10, translation: "TypeRangeFloat" = 10): ...


class RandomBiasField(IntensityTransform):
    """Simulate MRI magnetic field inhomogeneity (bias field)."""

    def __init__(self, coefficients: "TypeRangeFloat" = 0.5): ...
class RandomGhosting(IntensityTransform):
    def __init__(self, num_ghosts: tuple[int, int] = (4, 10), axes: tuple[int, ...] = (0, 1, 2)): ...

Tools for combining and organizing transforms into pipelines, including sequential composition, random selection from groups of transforms, and custom lambda transforms.
class Compose(Transform):
    """Chain several transforms, applying them one after another."""

    def __init__(self, transforms: Sequence[Transform]): ...


class OneOf(Transform):
    """Pick one transform at random from a group.

    ``transforms`` maps each candidate transform to a float, presumably its
    relative selection weight (confirm against the TorchIO docs).
    """

    def __init__(self, transforms: dict[Transform, float]): ...
class Lambda(Transform):
    def __init__(self, function: Callable, types_to_apply: tuple[type, ...] = None): ...

Pre-built datasets for medical imaging research, including brain atlases, public datasets from medical imaging challenges, and synthetic datasets for testing and development.
# Brain atlases and templates
class Colin27(Subject):
    """The Colin27 brain template."""


class ICBM2009CNonlinearSymmetric(Subject):
    """The ICBM 2009c nonlinear symmetric brain template."""


# Public datasets
class IXI(SubjectsDataset):
    """IXI dataset: brain MR images of healthy subjects."""


class RSNAMICCAI(SubjectsDataset):
    """Dataset of the RSNA-MICCAI Brain Tumor Radiogenomic Classification challenge."""
# MedMNIST 3D datasets
class OrganMNIST3D(SubjectsDataset):
    """3D organ classification dataset from MedMNIST"""

Helper functions, type definitions, and constants for medical image processing, including file I/O utilities, type conversion functions, and medical imaging constants.
from pathlib import Path
from typing import Any, Union

# Type definitions, declared first so the function signatures below can use them
TypePath = Union[str, Path]
TypeSpatialShape = Union[int, tuple[int, int, int]]
TypeSpacing = Union[float, tuple[float, float, float]]


# Utility functions
# "TypeNumber" and "Transform" are quoted because they are declared later in
# this document.
def to_tuple(value: Any, length: int = 1) -> tuple["TypeNumber", ...]:
    """Convert ``value`` to a tuple of numbers (``length`` defaults to 1)."""
    ...


def apply_transform_to_file(transform: "Transform", input_path: TypePath, output_path: TypePath):
    """Apply ``transform`` to the image at ``input_path`` and write it to ``output_path``.

    NOTE(review): upstream TorchIO appears to order the arguments as
    ``(input_path, transform, output_path, ...)``. Pass them by keyword and
    confirm against the installed version.
    """
    ...


def get_torchio_cache_dir() -> Path:
    """Return the directory TorchIO uses as its cache, as a ``Path``."""
    ...


# Constants: string keys/tags used across the API
# (e.g. ``batch['t1'][tio.DATA]`` in the quick-start loop)
INTENSITY = 'intensity'
LABEL = 'label'
DATA = 'data'
AFFINE = 'affine'

CLI tools for common medical image processing operations, providing convenient command-line access to TorchIO functionality.
# Available CLI commands:
# tiohd - Print image information and optionally display
# tiotr/torchio-transform - Apply transforms to images from the command line

tiohd - Print image header information and optionally visualize the image:
Options: --plot/-p (plot using matplotlib), --show/-s (show in an external viewer), --label/-l (treat as a label image), --load (load the image data to report memory information)

Example:

tiohd input.nii.gz --plot --show

tiotr/torchio-transform - Apply transforms to images:
Arguments: input_path, transform_name, output_path

Options: --kwargs/-k (transform parameters), --imclass/-c (image class), --seed/-s (random seed), --verbose/-v (verbose output)

Example:

tiotr input.nii.gz RandomFlip output.nii.gz --kwargs "axes=(0,1)"

# Basic types
TypePath = Union[str, Path]
TypeNumber = Union[int, float]
TypeData = Union[torch.Tensor, np.ndarray]
TypeDataAffine = tuple[torch.Tensor, np.ndarray]
TypeSlice = Union[int, slice]
TypeKeys = Optional[Sequence[str]]

# Fixed-length numeric tuples, grouped by arity (int variant, then float)
TypeDoubletInt = tuple[int, int]
TypeDoubleFloat = tuple[float, float]

TypeTripletInt = tuple[int, int, int]
TypeTripletFloat = tuple[float, float, float]

TypeQuartetInt = tuple[int, int, int, int]
TypeQuartetFloat = tuple[float, float, float, float]

TypeSextetInt = tuple[int, int, int, int, int, int]
TypeSextetFloat = tuple[float, float, float, float, float, float]

# Direction matrices as flat float tuples: 4 values in 2D, 9 in 3D
# (presumably flattened 2x2 / 3x3 matrices)
TypeDirection2D = TypeQuartetFloat
TypeDirection3D = tuple[
    float, float, float,
    float, float, float,
    float, float, float,
]
TypeDirection = Union[TypeDirection2D, TypeDirection3D]

# Geometric parameters: a single scalar, or an explicit tuple
TypeTuple = Union[int, TypeTripletInt]
TypeRangeInt = Union[int, TypeDoubletInt]
TypeSpatialShape = Union[int, TypeTripletInt]
TypeSpacing = Union[float, TypeTripletFloat]
TypeRangeFloat = Union[float, TypeDoubleFloat]

# Transform callables: tensor in, tensor out
TypeCallable = Callable[[torch.Tensor], torch.Tensor]
# Transform hierarchy
class Transform:
    """Common base of every transform."""


class SpatialTransform(Transform):
    """Base for spatial (geometric) transformations."""


class IntensityTransform(Transform):
    """Base for intensity transformations."""


# Image types
class Image:
    """Common base of all medical image classes."""


class ScalarImage(Image):
    """Scalar (intensity) medical image, such as MRI or CT."""


class LabelMap(ScalarImage):
    """Label (segmentation) image."""


class Subject(dict):
    """Dict-based container holding a subject's images and metadata."""