Tools for medical image processing with PyTorch
Flexible sampling strategies for extracting patches from 3D medical volumes, supporting uniform sampling, weighted sampling based on probability maps, label-focused sampling, and grid-based sampling for inference. Essential for patch-based training and inference on large medical images.
Abstract base class for all patch sampling strategies, defining the interface for patch extraction from medical image subjects.
class PatchSampler:
"""
Base class for patch sampling strategies.
Parameters:
- patch_size: Size of patches to extract (int or tuple of 3 ints)
"""
def __init__(self, patch_size: TypeSpatialShape): ...
def __call__(self, sample: Subject) -> dict:
"""
Extract patch from subject.
Parameters:
- sample: Subject to sample from
Returns:
Dictionary with patch data and location information
"""
@property
def patch_size(self) -> tuple[int, int, int]:
"""Get patch size as tuple"""Random uniform patch sampling that extracts patches from random locations throughout the volume with equal probability for all valid positions.
class UniformSampler(PatchSampler):
"""
Random uniform patch sampling.
Extracts patches from random locations with equal probability.
Useful for general training when no specific region needs emphasis.
Parameters:
- patch_size: Size of patches to extract
"""
def __init__(self, patch_size: TypeSpatialShape): ...Usage example:
import torchio as tio
# Create uniform sampler for 64x64x64 patches
sampler = tio.data.UniformSampler(patch_size=64)
# Create subject
subject = tio.Subject(
t1=tio.ScalarImage('t1.nii.gz'),
seg=tio.LabelMap('segmentation.nii.gz')
)
# Extract patch
patch = sampler(subject)
# Access patch data
t1_patch = patch['t1'][tio.DATA] # Shape: (1, 64, 64, 64)
seg_patch = patch['seg'][tio.DATA] # Shape: (1, 64, 64, 64)
location = patch[tio.LOCATION] # Patch location in original imageWeighted random patch sampling based on probability maps, allowing patches to be extracted with higher probability from specific regions of interest.
class WeightedSampler(PatchSampler):
"""
Weighted random patch sampling based on probability maps.
Extracts patches with probability proportional to values in
probability map. Useful for focusing sampling on regions of interest.
Parameters:
- patch_size: Size of patches to extract
- probability_map: Name of image to use as probability map
"""
def __init__(
self,
patch_size: TypeSpatialShape,
probability_map: str
): ...Usage example:
# Create probability map (higher values = higher sampling probability)
probability_map = tio.ScalarImage('probability_map.nii.gz')
subject = tio.Subject(
t1=tio.ScalarImage('t1.nii.gz'),
probability_map=probability_map
)
# Create weighted sampler
sampler = tio.data.WeightedSampler(
patch_size=64,
probability_map='probability_map'
)
# Extract patch (more likely from high-probability regions)
patch = sampler(subject)Label-focused patch sampling that extracts patches containing specific labels, with configurable probabilities for different label values. Ideal for training on segmentation tasks or focusing on specific anatomical structures.
class LabelSampler(WeightedSampler):
"""
Patch sampling focused on specific labels.
Extracts patches that contain specified labels with configurable
probabilities. Automatically creates probability map from label image.
Parameters:
- patch_size: Size of patches to extract
- label_name: Name of label image in subject
- label_probabilities: Dict mapping label values to sampling probabilities
"""
def __init__(
self,
patch_size: TypeSpatialShape,
label_name: str,
label_probabilities: dict = None
): ...Usage example:
# Define sampling probabilities for different labels
label_probs = {
0: 0.1, # Background: low probability
1: 0.8, # Tumor: high probability
2: 0.4, # Edema: medium probability
3: 0.6, # Necrosis: medium-high probability
}
subject = tio.Subject(
t1=tio.ScalarImage('t1.nii.gz'),
seg=tio.LabelMap('tumor_segmentation.nii.gz')
)
# Create label sampler
sampler = tio.data.LabelSampler(
patch_size=64,
label_name='seg',
label_probabilities=label_probs
)
# Extract patch (more likely to contain tumor regions)
patch = sampler(subject)Grid-based patch sampling for systematic coverage of the entire volume, primarily used for inference to ensure complete volume coverage without gaps.
class GridSampler(PatchSampler):
"""
Regular grid-based patch sampling for inference.
Extracts patches in a regular grid pattern to ensure complete
volume coverage. Typically used for inference rather than training.
Parameters:
- patch_size: Size of patches to extract
- patch_overlap: Overlap between adjacent patches (int or tuple)
"""
def __init__(
self,
patch_size: TypeSpatialShape,
patch_overlap: TypeSpatialShape = 0
): ...
def __iter__(self):
"""Iterate through all patches in grid"""
@property
def num_patches(self) -> int:
"""Total number of patches in grid"""Usage example:
subject = tio.Subject(
t1=tio.ScalarImage('t1.nii.gz')
)
# Create grid sampler with 50% overlap
sampler = tio.data.GridSampler(
patch_size=64,
patch_overlap=32 # 50% overlap
)
# Extract all patches systematically
all_patches = []
for patch in sampler(subject):
all_patches.append(patch)
print(f"Total patches: {len(all_patches)}")Aggregates predictions from grid-sampled patches back into full-volume predictions, handling overlapping regions and ensuring proper reconstruction.
class GridAggregator:
"""
Aggregates predictions from grid-sampled patches.
Reconstructs full-volume predictions from overlapping patches,
handling various aggregation strategies for overlapping regions.
Parameters:
- sampler: GridSampler used to extract patches
- overlap_mode: How to handle overlapping regions ('crop', 'average', 'hann')
"""
def __init__(
self,
sampler: GridSampler,
overlap_mode: str = 'crop'
): ...
def add_batch(
self,
batch_tensor: torch.Tensor,
batch_locations: torch.Tensor
):
"""Add batch of predictions to aggregator"""
def get_output_tensor(self) -> torch.Tensor:
"""Get aggregated full-volume prediction"""Usage example:
import torch
import torchio as tio
# Setup for inference
subject = tio.Subject(t1=tio.ScalarImage('t1.nii.gz'))
patch_size = 64
patch_overlap = 16
# Create grid sampler and aggregator
grid_sampler = tio.data.GridSampler(
patch_size=patch_size,
patch_overlap=patch_overlap
)
aggregator = tio.data.GridAggregator(
sampler=grid_sampler,
overlap_mode='average'
)
# Model inference on patches
model = load_your_model()
model.eval()
with torch.no_grad():
for patch in grid_sampler(subject):
# Get patch data and location
patch_tensor = patch['t1'][tio.DATA].unsqueeze(0) # Add batch dim
location = patch[tio.LOCATION]
# Run inference
prediction = model(patch_tensor)
# Add to aggregator
aggregator.add_batch(prediction, location.unsqueeze(0))
# Get full-volume prediction
full_prediction = aggregator.get_output_tensor()
print(f"Full prediction shape: {full_prediction.shape}")Utility functions for patch sampling operations and analysis.
def get_batch_images_and_size(batch: dict) -> tuple[list[str], int]:
"""Extract image names and batch size from patch batch"""
def parse_spatial_shape(shape) -> tuple[int, int, int]:
"""Parse spatial shape specification into 3D tuple"""Install with Tessl CLI
npx tessl i tessl/pypi-torchio