tessl/pypi-torchio

Tools for medical image processing with PyTorch

Workspace: tessl
Visibility: Public
Describes: pkg:pypi/torchio@0.20.x

To install, run

npx @tessl/cli install tessl/pypi-torchio@0.20.0


TorchIO

TorchIO is a comprehensive Python package designed for efficient processing of 3D medical images in deep learning applications built with PyTorch. It provides a complete toolkit for reading, preprocessing, sampling, augmenting, and writing medical images, featuring both standard computer vision transforms and domain-specific medical imaging transforms that simulate realistic artifacts such as MRI magnetic field inhomogeneity and k-space motion artifacts.

Package Information

  • Package Name: torchio
  • Language: Python
  • Installation: pip install torchio

Core Imports

import torchio as tio

For specific components:

from torchio import Subject, ScalarImage, LabelMap
from torchio import SubjectsDataset, SubjectsLoader
from torchio import Compose, RandomFlip, RandomAffine
from torchio.data import UniformSampler, GridSampler

Basic Usage

import torchio as tio

# Create a subject with medical images
subject = tio.Subject(
    t1=tio.ScalarImage('t1.nii.gz'),
    t2=tio.ScalarImage('t2.nii.gz'),
    seg=tio.LabelMap('seg.nii.gz'),
    age=45,
    name='Subject_001'
)

# Define preprocessing and augmentation transforms
transform = tio.Compose([
    tio.ToCanonical(),              # Reorient to canonical orientation
    tio.Resample(1),                # Resample to 1 mm isotropic
    tio.CropOrPad((128, 128, 64)),  # Crop or pad to target shape
    tio.ZNormalization(),           # Z-score normalization
    tio.RandomFlip(),               # Random flipping
    tio.RandomAffine(),             # Random affine transformation
    tio.RandomNoise(std=0.1),       # Add random noise
])

# Apply transforms to subject
transformed_subject = transform(subject)

# Create dataset for training (one Subject per patient/case; add more as needed)
subjects = [subject]
dataset = tio.SubjectsDataset(subjects, transform=transform)

# Create a patch queue and loader for patch-based training
patch_size = 64
samples_per_volume = 10
sampler = tio.data.UniformSampler(patch_size)
patches_queue = tio.Queue(
    subjects_dataset=dataset,
    max_length=100,
    samples_per_volume=samples_per_volume,
    sampler=sampler,
)
# SubjectsLoader collates Subject patches into batches; the queue loads
# subjects itself, so the loader should use num_workers=0
loader = tio.SubjectsLoader(patches_queue, batch_size=4, num_workers=0)

# Training loop
for batch in loader:
    # batch contains patches ready for training
    inputs = batch['t1'][tio.DATA]
    targets = batch['seg'][tio.DATA]
    # ... train your model

Architecture

TorchIO follows a hierarchical design optimized for medical image processing workflows:

  • Subject: Dictionary-like container storing multiple medical images and metadata for a single patient/case
  • Image Types: Specialized classes for different image modalities (ScalarImage for intensity images, LabelMap for segmentations)
  • Transform System: Hierarchical transform classes (Transform → SpatialTransform/IntensityTransform) with history tracking
  • Sampling Strategies: Flexible patch sampling for 3D volumes (uniform, weighted, label-based, grid-based)
  • Data Pipeline: PyTorch-compatible datasets and loaders optimized for medical imaging workflows

Capabilities

Core Data Structures

Essential data structures for handling medical images, including the Subject container, various Image types, and dataset management for organizing multiple subjects.

class Subject(dict):
    def __init__(self, *args, **kwargs: dict[str, Any]): ...
    def get_images(self, intensity_only: bool = True) -> list[Image]: ...
    def check_consistent_spatial_shape(self): ...

class Image:
    def __init__(self, path: TypePath = None, type: str = None, tensor: TypeData = None, affine: TypeData = None, **kwargs): ...
    @property
    def data(self) -> torch.Tensor: ...
    @property
    def affine(self) -> np.ndarray: ...

class ScalarImage(Image):
    """Represents intensity/scalar medical images (e.g., MRI, CT scans)"""

class LabelMap(Image):
    """Represents segmentation/label images"""

class SubjectsDataset(torch.utils.data.Dataset):
    def __init__(self, subjects: Sequence[Subject], transform: Transform = None): ...

Data Loading and Management

PyTorch-compatible data loading utilities optimized for medical imaging, including specialized data loaders, queues for patch-based training, and efficient batch processing.

class SubjectsLoader(torch.utils.data.DataLoader):
    """PyTorch DataLoader wrapper optimized for medical image subjects"""

class Queue(torch.utils.data.Dataset):
    def __init__(
        self,
        subjects_dataset: SubjectsDataset,
        max_length: int,
        samples_per_volume: int,
        sampler: PatchSampler,
        **kwargs
    ): ...

Patch Sampling Strategies

Flexible sampling strategies for extracting patches from 3D medical volumes, supporting uniform sampling, weighted sampling based on probability maps, label-focused sampling, and grid-based sampling for inference.

class PatchSampler:
    """Base class for patch sampling strategies"""
    def __call__(self, subject: Subject, num_patches: int = None) -> Generator[Subject, None, None]: ...  # yields patches as Subjects

class UniformSampler(PatchSampler):
    def __init__(self, patch_size: TypeSpatialShape): ...

class WeightedSampler(PatchSampler):
    def __init__(self, patch_size: TypeSpatialShape, probability_map: str): ...

class LabelSampler(WeightedSampler):
    def __init__(self, patch_size: TypeSpatialShape, label_name: str, label_probabilities: dict = None): ...

class GridSampler(PatchSampler):
    def __init__(self, subject: Subject = None, patch_size: TypeSpatialShape = None, patch_overlap: TypeSpatialShape = (0, 0, 0), padding_mode: str = None): ...

class GridAggregator:
    def __init__(self, sampler: GridSampler, overlap_mode: str = 'crop'): ...

Preprocessing Transforms

Comprehensive preprocessing transforms for medical images including spatial transformations (resampling, cropping, padding), intensity normalization (z-score, histogram standardization), and specialized medical imaging preprocessing.

# Spatial preprocessing
class Resample(SpatialTransform):
    def __init__(self, target: TypeSpacing, image_interpolation: str = 'linear'): ...

class CropOrPad(SpatialTransform):
    def __init__(self, target_shape: TypeSpatialShape, padding_mode: str = 'constant'): ...

class ToCanonical(SpatialTransform):
    """Reorients images to canonical orientation (RAS+)"""

# Intensity preprocessing
class ZNormalization(IntensityTransform):
    def __init__(self, masking_method: str = None): ...

class RescaleIntensity(IntensityTransform):
    def __init__(self, out_min_max: tuple[float, float] = (0, 1)): ...

class HistogramStandardization(IntensityTransform):
    def __init__(self, landmarks: dict): ...

Augmentation Transforms

Extensive augmentation transforms including spatial augmentations (affine, elastic deformation, flipping) and intensity augmentations with medical imaging-specific artifacts (motion, ghosting, bias field, spike artifacts).

# Spatial augmentation
class RandomFlip(SpatialTransform):
    def __init__(self, axes: TypeTuple = (0,), flip_probability: float = 0.5): ...

class RandomAffine(SpatialTransform):
    def __init__(
        self,
        scales: TypeRangeFloat = 0.1,
        degrees: TypeRangeFloat = 10,
        translation: TypeRangeFloat = 0,
        **kwargs
    ): ...

class RandomElasticDeformation(SpatialTransform):
    def __init__(self, num_control_points: TypeTuple = 7, max_displacement: TypeRangeFloat = 7.5): ...

# Medical imaging specific augmentation
class RandomMotion(IntensityTransform):
    def __init__(self, degrees: TypeRangeFloat = 10, translation: TypeRangeFloat = 10): ...

class RandomBiasField(IntensityTransform):
    def __init__(self, coefficients: TypeRangeFloat = 0.5): ...

class RandomGhosting(IntensityTransform):
    def __init__(self, num_ghosts: tuple[int, int] = (4, 10), axes: tuple[int, ...] = (0, 1, 2)): ...

Transform Composition

Tools for combining and organizing transforms into pipelines, including sequential composition, random selection from transform groups, and custom lambda transforms.

class Compose(Transform):
    def __init__(self, transforms: Sequence[Transform]): ...

class OneOf(Transform):
    def __init__(self, transforms: dict[Transform, float]): ...

class Lambda(Transform):
    def __init__(self, function: Callable, types_to_apply: Sequence[str] = None): ...

Medical Image Datasets

Pre-built datasets for common medical imaging research, including brain atlases, public medical imaging challenges, and synthetic datasets for testing and development.

# Brain atlases and templates
class Colin27(Subject):
    """Colin27 brain template"""

class ICBM2009CNonlinearSymmetric(Subject):
    """ICBM 2009c nonlinear symmetric brain template"""

# Public datasets
class IXI(SubjectsDataset):
    """IXI dataset - brain MR images from healthy subjects"""

class RSNAMICCAI(SubjectsDataset):
    """RSNA-MICCAI Brain Tumor Radiogenomic Classification dataset"""

# MedMNIST 3D datasets
class OrganMNIST3D(SubjectsDataset):
    """3D organ classification dataset from MedMNIST"""

Utilities and Constants

Helper functions, type definitions, and constants for medical image processing, including file I/O utilities, type conversion functions, and medical imaging constants.

# Utility functions
def to_tuple(value: Any, length: int = 1) -> tuple[TypeNumber, ...]: ...
def apply_transform_to_file(input_path: TypePath, transform: Transform, output_path: TypePath): ...
def get_torchio_cache_dir() -> Path: ...

# Type definitions
TypePath = Union[str, Path]
TypeSpatialShape = Union[int, tuple[int, int, int]]
TypeSpacing = Union[float, tuple[float, float, float]]

# Constants
INTENSITY = 'intensity'
LABEL = 'label'
DATA = 'data'
AFFINE = 'affine'

Command Line Interface

CLI tools for common medical image processing operations, providing convenient command-line access to TorchIO functionality.

# Available CLI commands:
# tiohd - Print image information and optionally display
# tiotr/torchio-transform - Apply transforms to images from command line

tiohd - Print image header information and optionally visualize:

  • Options: --plot/-p (plot using matplotlib), --show/-s (show in external viewer), --label/-l (treat as label image), --load (load data for memory info)
  • Usage: tiohd input.nii.gz --plot --show

tiotr/torchio-transform - Apply transforms to images:

  • Arguments: input_path, transform_name, output_path
  • Options: --kwargs/-k (transform parameters), --imclass/-c (image class), --seed/-s (random seed), --verbose/-v
  • Usage: tiotr input.nii.gz RandomFlip output.nii.gz --kwargs "axes=(0,1)"

Types

# Basic types
TypePath = Union[str, Path]
TypeNumber = Union[int, float]
TypeData = Union[torch.Tensor, np.ndarray]
TypeDataAffine = tuple[torch.Tensor, np.ndarray]
TypeSlice = Union[int, slice]
TypeKeys = Optional[Sequence[str]]

# Numeric tuple types
TypeDoubletInt = tuple[int, int]
TypeTripletInt = tuple[int, int, int]
TypeQuartetInt = tuple[int, int, int, int]
TypeSextetInt = tuple[int, int, int, int, int, int]

TypeDoubleFloat = tuple[float, float]
TypeTripletFloat = tuple[float, float, float]
TypeQuartetFloat = tuple[float, float, float, float]
TypeSextetFloat = tuple[float, float, float, float, float, float]

# Geometric types
TypeTuple = Union[int, TypeTripletInt]
TypeRangeInt = Union[int, TypeDoubletInt]
TypeSpatialShape = Union[int, TypeTripletInt]
TypeSpacing = Union[float, TypeTripletFloat]
TypeRangeFloat = Union[float, TypeDoubleFloat]

# Transform types
TypeCallable = Callable[[torch.Tensor], torch.Tensor]

# Direction matrix types
TypeDirection2D = TypeQuartetFloat
TypeDirection3D = tuple[float, float, float, float, float, float, float, float, float]
TypeDirection = Union[TypeDirection2D, TypeDirection3D]

# Image types
class Image:
    """Base class for medical images"""
    
class ScalarImage(Image):
    """Intensity/scalar medical images (MRI, CT, etc.)"""
    
class LabelMap(Image):
    """Segmentation/label images"""

class Subject(dict):
    """Container for multiple medical images and metadata"""

# Transform hierarchy
class Transform:
    """Base class for all transforms"""
    
class SpatialTransform(Transform):
    """Base class for spatial transformations"""
    
class IntensityTransform(Transform):
    """Base class for intensity transformations"""