Tools for medical image processing with PyTorch
Helper functions, type definitions, and constants for medical image processing, including file I/O utilities, type conversion functions, command-line tools, and medical imaging constants that support TorchIO's core functionality.
Utility functions for converting and validating data types commonly used in medical image processing.
def to_tuple(value: Any, length: int = 1) -> tuple[TypeNumber, ...]:
    """
    Convert variable to tuple of specified length.
    If value is iterable, length is ignored and tuple(value) is returned.
    If value is scalar, it's repeated length times.
    Parameters:
    - value: Value to convert to tuple
    - length: Target tuple length (used only for scalar input)
    Returns:
    Tuple of specified length
    Examples:
    >>> to_tuple(1, length=3)
    (1, 1, 1)
    >>> to_tuple([1, 2], length=3)
    (1, 2)
    """
    try:
        # Iterable input: length is intentionally ignored, per the contract.
        return tuple(value)
    except TypeError:
        # Scalar input: repeat it to reach the requested length.
        return (value,) * length
def check_sequence(sequence: Sequence, name: str) -> None:
    """
    Validate sequence parameters.
    Parameters:
    - sequence: Sequence to validate
    - name: Name used to identify the parameter in error messages
    Raises:
    ValueError: If sequence is invalid
    """
    # NOTE(review): the exact validation rules ("invalid") are not visible
    # in this spec — confirm against the implementation before relying on them.
def is_iterable(object: Any) -> bool:
    """
    Check if object is iterable.
    Parameters:
    - object: Object to check
    Returns:
    True if object is iterable, False otherwise
    """
    # iter() is the authoritative test: it covers both __iter__ and the
    # legacy sequence protocol (__getitem__), unlike hasattr('__iter__').
    # (Parameter name shadows the builtin, but is kept for interface
    # compatibility with existing keyword callers.)
    try:
        iter(object)
    except TypeError:
        return False
    return True
def parse_spatial_shape(shape) -> tuple[int, int, int]:
    """
    Parse spatial shape specification into 3D tuple.
    Parameters:
    - shape: Shape specification (int, or sequence of 1 or 3 ints)
    Returns:
    3D spatial shape tuple
    Raises:
    ValueError: If the specification cannot be expanded to three dimensions
    """
    try:
        result = tuple(shape)
    except TypeError:
        # A scalar is broadcast to all three spatial dimensions.
        result = (shape,) * 3
    if len(result) == 1:
        # A single-element sequence behaves like a scalar.
        result = result * 3
    if len(result) != 3:
        message = f'Spatial shape must have 1 or 3 elements, but got: {shape}'
        raise ValueError(message)
    return tuple(int(n) for n in result)
def guess_type(string: str) -> Any:
    """
    Guess variable type from string representation.
    Attempts to convert string to appropriate Python type
    (int, float, bool, None, or keeps as string).
    Parameters:
    - string: String to parse
    Returns:
    Parsed value with appropriate type
    """
    import ast

    try:
        # literal_eval safely parses Python literals: ints, floats,
        # booleans, None, and container literals — never arbitrary code.
        value = ast.literal_eval(string)
    except (ValueError, SyntaxError):
        # Not a recognizable literal: keep the raw string.
        value = string
    return value


# Functions for handling file paths, stems, and compression operations
# commonly needed in medical imaging workflows.
def get_stem(path: TypePath) -> str:
    """
    Extract file stem from path (filename without extension).
    Handles complex medical imaging extensions like .nii.gz
    Parameters:
    - path: File path
    Returns:
    File stem without extensions
    Examples:
    >>> get_stem('image.nii.gz')
    'image'
    >>> get_stem('/path/to/scan.dcm')
    'scan'
    """
    name = Path(path).name
    # Join and strip *all* suffixes so compound extensions such as
    # .nii.gz are removed in a single pass.
    suffixes = ''.join(Path(name).suffixes)
    if suffixes:
        name = name[: -len(suffixes)]
    return name
def normalize_path(path: TypePath) -> Path:
    """
    Normalize file path to Path object.
    Expands a leading '~' to the user's home directory.
    Parameters:
    - path: Path as string or Path object
    Returns:
    Normalized Path object
    """
    # NOTE(review): the spec only says "normalize"; some implementations
    # also resolve to an absolute path — confirm before relying on that.
    return Path(path).expanduser()
def get_torchio_cache_dir() -> Path:
    """
    Get TorchIO cache directory for temporary files.
    Returns platform-appropriate cache directory for TorchIO,
    creating it if it doesn't exist.
    Returns:
    Path to TorchIO cache directory
    """
    # NOTE(review): the platform-specific location (e.g. ~/.cache/torchio)
    # is not visible in this spec — confirm against the implementation.
def compress(
    input_path: TypePath,
    output_path: Optional[TypePath] = None
) -> Path:
    """
    Compress file using gzip compression.
    Parameters:
    - input_path: Path to file to compress
    - output_path: Output path (optional, defaults to input_path + '.gz')
    Returns:
    Path to compressed file
    """
    import gzip
    import shutil

    input_path = Path(input_path)
    if output_path is None:
        output_path = input_path.with_name(input_path.name + '.gz')
    output_path = Path(output_path)
    # Stream through copyfileobj so large images are never fully loaded
    # into memory.
    with open(input_path, 'rb') as f_in, gzip.open(output_path, 'wb') as f_out:
        shutil.copyfileobj(f_in, f_out)
    return output_path


# Utility functions for creating datasets, applying transforms to files,
# and managing data operations.
def create_dummy_dataset(
    num_subjects: int,
    size_range: tuple[int, int] = (10, 20),
    directory: Optional[Path] = None,
    random_seed: int = 42,
    **kwargs
) -> SubjectsDataset:
    """
    Create dummy dataset for testing and development.
    Generates synthetic medical imaging data with realistic
    structure for testing transforms and training loops.
    Parameters:
    - num_subjects: Number of subjects to create
    - size_range: Range of image sizes (min, max)
    - directory: Directory to save dummy images (None for temporary)
    - random_seed: Random seed for reproducibility
    - kwargs: Extra options forwarded to the implementation
      (accepted keys are not visible in this spec — confirm before use)
    Returns:
    SubjectsDataset with dummy subjects
    """
def apply_transform_to_file(
    transform: 'Transform',
    input_path: TypePath,
    output_path: TypePath,
    verbose: bool = False
) -> None:
    """
    Apply transform to image file and save result.
    Loads image from input_path, applies transform, and saves
    to output_path. Useful for batch processing of images.
    Parameters:
    - transform: Transform to apply
    - input_path: Path to input image
    - output_path: Path to save transformed image
    - verbose: Whether to print progress information
    """


# Functions for handling batched medical image data and extracting
# information from data loaders.
def history_collate(batch: Sequence, collate_transforms: bool = True) -> dict:
    """
    Custom collate function that preserves transform history.
    Intended for use as the ``collate_fn`` of a PyTorch DataLoader.
    Parameters:
    - batch: Sequence of Subject instances
    - collate_transforms: Whether to collate transform histories
    Returns:
    Collated batch dictionary with preserved metadata
    """
def get_first_item(data_loader: torch.utils.data.DataLoader):
    """
    Get first item from data loader for inspection.
    Useful for debugging data loading pipelines and
    inspecting batch structure.
    Parameters:
    - data_loader: DataLoader to get item from
    Returns:
    First item from data loader
    """
    # NOTE(review): return type is unannotated — presumably the collated
    # batch dict produced by the loader; confirm against the implementation.
def get_batch_images_and_size(batch: dict) -> tuple[list[str], int]:
    """
    Extract image names and batch size from batch.
    Parameters:
    - batch: Batch dictionary from TorchIO data loader
    Returns:
    Tuple of (list of image names, batch size)
    """
def get_subjects_from_batch(batch: dict) -> list:
    """
    Extract individual subjects from batched data.
    Reconstructs Subject instances from collated batch data.
    Parameters:
    - batch: Batched data dictionary
    Returns:
    List of Subject instances
    """
def add_images_from_batch(
    images_dict: dict,
    batch: dict,
    batch_idx: int
) -> None:
    """
    Add images from batch to images dictionary.
    Mutates images_dict in place; returns None.
    Parameters:
    - images_dict: Dictionary to add images to
    - batch: Batch containing images
    - batch_idx: Index of current batch
    """


# Functions for system compatibility checks and introspection.
def get_major_sitk_version() -> int:
    """
    Get major version of SimpleITK library.
    Returns:
    Major version number of installed SimpleITK
    """
    # NOTE(review): presumably parsed from SimpleITK's version string;
    # behavior when SimpleITK is not installed is not specified here.
def get_subclasses(target_class: type) -> list[type]:
    """
    Get all subclasses of a class recursively.
    Parameters:
    - target_class: Class to find subclasses of
    Returns:
    List of all subclasses (direct and indirect)
    """
    result = []
    # __subclasses__ only yields direct children, so recurse to collect
    # grandchildren and deeper.
    for subclass in target_class.__subclasses__():
        result.append(subclass)
        result.extend(get_subclasses(subclass))
    return result
def guess_external_viewer() -> Path | None:
    """
    Guess external image viewer application.
    Attempts to find suitable medical image viewer
    on the current system (ITK-SNAP, 3D Slicer, etc.).
    Returns:
    Path to viewer executable, or None if not found
    """


# Comprehensive type definitions for medical image processing, providing
# type hints and validation for TorchIO operations.
# Path and basic types
TypePath = Union[str, Path]
TypeNumber = Union[int, float]
TypeKeys = Optional[Sequence[str]]
# Data types
TypeData = Union[torch.Tensor, np.ndarray]
TypeDataAffine = tuple[torch.Tensor, np.ndarray]
TypeSlice = Union[int, slice]
# Geometric types (fixed-arity integer/float tuples)
TypeDoubletInt = tuple[int, int]
TypeTripletInt = tuple[int, int, int]
TypeQuartetInt = tuple[int, int, int, int]
TypeSextetInt = tuple[int, int, int, int, int, int]
TypeDoubleFloat = tuple[float, float]
TypeTripletFloat = tuple[float, float, float]
TypeQuartetFloat = tuple[float, float, float, float]
TypeSextetFloat = tuple[float, float, float, float, float, float]
# Spatial types: scalars are accepted and broadcast to 3D where relevant
TypeTuple = Union[int, TypeTripletInt]
TypeRangeInt = Union[int, TypeDoubletInt]
TypeSpatialShape = Union[int, TypeTripletInt]
TypeSpacing = Union[float, TypeTripletFloat]
TypeRangeFloat = Union[float, TypeDoubleFloat]
# Direction cosine matrices (flattened row-major)
TypeDirection2D = TypeQuartetFloat
TypeDirection3D = tuple[
    float, float, float,
    float, float, float,
    float, float, float,
]
TypeDirection = Union[TypeDirection2D, TypeDirection3D]
# Functional types
TypeCallable = Callable[[torch.Tensor], torch.Tensor]

# Essential constants used throughout TorchIO for consistent behavior
# and standardization.
# Image type constants
INTENSITY = 'intensity'  # For scalar/intensity images
LABEL = 'label'  # For segmentation/label images
SAMPLING_MAP = 'sampling_map'  # For sampling probability maps
# Dictionary keys for subject data
PATH = 'path'  # File path key
TYPE = 'type'  # Image type key
STEM = 'stem'  # File stem key
DATA = 'data'  # Image data key
AFFINE = 'affine'  # Affine matrix key
TENSOR = 'tensor'  # Tensor data key
# Aggregator and queue keys
IMAGE = 'image'  # Image key for aggregator
LOCATION = 'location'  # Location key for aggregator
HISTORY = 'history'  # Transform history key
NUM_SAMPLES = 'num_samples'  # Queue samples key
# Technical constants
CHANNELS_DIMENSION = 1  # PyTorch channel dimension index
# NOTE(review): despite the name, this is float32 machine epsilon
# (~1.19e-07), not the smallest representable float32.
MIN_FLOAT_32 = torch.finfo(torch.float32).eps  # Float32 epsilon
# Repository URLs
REPO_URL = 'https://github.com/TorchIO-project/torchio/'
DATA_REPO = 'https://github.com/TorchIO-project/torchio-data/raw/main/data/'

# TorchIO provides command-line tools for common medical image processing
# tasks.
# CLI applications (available as shell commands after pip install)
# tiohd - Print image information
# tiotr - Apply transforms
# torchio-transform - Apply transforms (alias)

# Command-line usage examples:
# Print image information
tiohd image.nii.gz
# Apply transform to image
tiotr input.nii.gz output.nii.gz --transform RandomFlip
# Apply multiple transforms
tiotr input.nii.gz output.nii.gz \
--transform ToCanonical \
--transform "Resample(1)" \
    --transform "RandomAffine(degrees=5)"

Utilities for downloading datasets and managing file integrity.
def calculate_md5(fpath: TypePath, chunk_size: int = 1024 * 1024) -> str:
    """
    Calculate MD5 hash of file.
    Parameters:
    - fpath: Path to the file
    - chunk_size: Bytes read per iteration, so large files are hashed
      without being loaded fully into memory
    Returns:
    Hexadecimal MD5 digest of the file contents
    """
    import hashlib

    md5 = hashlib.md5()
    with open(fpath, 'rb') as f:
        # Walrus keeps the read-check-update loop compact.
        while chunk := f.read(chunk_size):
            md5.update(chunk)
    return md5.hexdigest()
def check_md5(fpath: TypePath, md5: str, **kwargs) -> bool:
    """
    Verify MD5 hash of file.
    Parameters:
    - fpath: Path to the file
    - md5: Expected hexadecimal MD5 digest
    - kwargs: Forwarded to calculate_md5 (e.g. chunk_size)
    Returns:
    True if the computed digest matches md5
    """
    return md5 == calculate_md5(fpath, **kwargs)
def check_integrity(fpath: TypePath, md5: Optional[str] = None) -> bool:
    """
    Check file exists and optionally verify MD5.
    Parameters:
    - fpath: Path to the file
    - md5: Expected MD5 digest; if None, only existence is checked
    Returns:
    True if the file exists and the digest (when given) matches
    """
    if not Path(fpath).is_file():
        return False
    if md5 is None:
        # Existence is all that was requested.
        return True
    return check_md5(fpath, md5)
def download_url(
    url: str,
    root: TypePath,
    filename: Optional[str] = None,
    md5: Optional[str] = None
) -> None:
    """
    Download file from URL with progress bar and integrity checking.
    Parameters:
    - url: URL to download from
    - root: Directory to save the downloaded file in
    - filename: Name for the saved file (None presumably derives it
      from the URL — confirm against the implementation)
    - md5: Expected MD5 digest for integrity verification (None skips it)
    """
def extract_archive(
    from_path: TypePath,
    to_path: Optional[TypePath] = None,
    remove_finished: bool = False
) -> None:
    """
    Extract compressed archive (zip, tar, tar.gz, etc.).
    Parameters:
    - from_path: Path to the archive
    - to_path: Directory to extract into (defaults to the archive's
      parent directory)
    - remove_finished: If True, delete the archive after extraction
    """
    import shutil

    from_path = Path(from_path)
    if to_path is None:
        to_path = from_path.parent
    # shutil.unpack_archive dispatches on the extension
    # (zip, tar, tar.gz, tar.bz2, tar.xz).
    shutil.unpack_archive(str(from_path), str(to_path))
    if remove_finished:
        from_path.unlink()


import torchio as tio
from pathlib import Path

# Create dummy dataset for testing
dummy_dataset = tio.utils.create_dummy_dataset(
    num_subjects=10,
    size_range=(32, 48),
    directory=Path('./test_data'),
    random_seed=42
)
print(f"Created {len(dummy_dataset)} dummy subjects")

# Use for testing transforms
test_transform = tio.Compose([
    tio.RandomFlip(),
    tio.RandomAffine(degrees=5),
    tio.RandomNoise(std=0.1),
])
dummy_dataset.set_transform(test_transform)

# Apply transform to multiple files
transform = tio.Compose([
    tio.ToCanonical(),
    tio.Resample(1),
    tio.ZNormalization(),
])

input_files = Path('./input').glob('*.nii.gz')
output_dir = Path('./processed')
output_dir.mkdir(exist_ok=True)

# Process every input image and write the result next to the others
for input_file in input_files:
    output_file = output_dir / input_file.name
    tio.utils.apply_transform_to_file(
        transform=transform,
        input_path=input_file,
        output_path=output_file,
        verbose=True
    )

# Inspect data loader batches
dataset = tio.SubjectsDataset([...])  # Your subjects
loader = tio.SubjectsLoader(dataset, batch_size=4)

# Get first batch for inspection
first_batch = tio.utils.get_first_item(loader)

# Extract batch information
image_names, batch_size = tio.utils.get_batch_images_and_size(first_batch)
print(f"Batch contains {batch_size} subjects with images: {image_names}")

# Reconstruct individual subjects from batch
subjects = tio.utils.get_subjects_from_batch(first_batch)
print(f"Reconstructed {len(subjects)} subjects from batch")

# Install with Tessl CLI:
#   npx tessl i tessl/pypi-torchio