Tools for medical image processing with PyTorch
PyTorch-compatible data loading utilities optimized for medical imaging workflows, including specialized data loaders, queues for patch-based training, and efficient batch processing of medical image subjects.
Specialized PyTorch DataLoader wrapper optimized for medical image subjects, providing efficient loading and batching of Subject instances with medical imaging-specific collate functions.
class SubjectsLoader(torch.utils.data.DataLoader):
"""
PyTorch DataLoader wrapper optimized for medical image subjects.
Provides custom collate function that handles the complex nested structure
of Subject instances containing multiple images and metadata.
"""
def __init__(
self,
dataset: SubjectsDataset,
batch_size: int = 1,
shuffle: bool = False,
num_workers: int = 0,
**kwargs
): ...

Usage example:
import torchio as tio
import torch
# Create dataset
subjects = [...] # List of subjects
dataset = tio.SubjectsDataset(subjects)
# Create optimized loader for subjects
loader = tio.SubjectsLoader(
dataset,
batch_size=4,
shuffle=True,
num_workers=4
)
for batch in loader:
# batch is a dict containing batched subject data
t1_batch = batch['t1'][tio.DATA] # Shape: (4, 1, D, H, W)
ages = batch['age'] # List of ages

Queue implementation for patch-based training that maintains a buffer of patches extracted from subjects using specified sampling strategies. Essential for training on large 3D medical volumes that don't fit in memory.
class Queue(torch.utils.data.Dataset):
"""
Queue for patch-based training with data augmentation.
Maintains a buffer of patches extracted from subjects, enabling
efficient training on large 3D volumes through patch-based sampling.
Parameters:
- subjects_dataset: SubjectsDataset containing subjects
- max_length: Maximum number of patches to keep in queue
- samples_per_volume: Number of patches to extract per subject
- sampler: PatchSampler for patch extraction strategy
- num_workers: Number of workers for parallel patch extraction
- shuffle_subjects: Whether to shuffle subjects
- shuffle_patches: Whether to shuffle patches
"""
def __init__(
self,
subjects_dataset: SubjectsDataset,
max_length: int,
samples_per_volume: int,
sampler: 'PatchSampler',
num_workers: int = 0,
shuffle_subjects: bool = True,
shuffle_patches: bool = True,
**kwargs
): ...
def __len__(self) -> int:
"""Return current number of patches in queue"""
def __getitem__(self, index: int) -> dict:
"""Get patch at index"""
def set_max_length(self, max_length: int):
"""Update maximum queue length"""Usage example:
# Create subjects dataset
subjects = [...] # List of subjects
dataset = tio.SubjectsDataset(subjects, transform=preprocessing_transform)
# Define patch sampling
patch_size = 64
sampler = tio.data.UniformSampler(patch_size)
# Create queue for patch-based training
patches_queue = tio.Queue(
subjects_dataset=dataset,
max_length=300, # Keep 300 patches in queue
samples_per_volume=10, # Extract 10 patches per subject
sampler=sampler,
num_workers=4,
shuffle_subjects=True,
shuffle_patches=True,
)
# Use with PyTorch DataLoader
patches_loader = torch.utils.data.DataLoader(
patches_queue,
batch_size=16,
num_workers=0 # Queue already handles parallelization
)
# Training loop
for batch in patches_loader:
inputs = batch['t1'][tio.DATA] # Shape: (16, 1, 64, 64, 64)
targets = batch['seg'][tio.DATA] # Shape: (16, 1, 64, 64, 64)
# Train model with patches

Utility functions for handling batched medical image data, including custom collate functions and batch analysis tools.
def history_collate(batch: Sequence, collate_transforms=True) -> dict:
"""
Custom collate function that preserves transform history.
Parameters:
- batch: Sequence of Subject instances
- collate_transforms: Whether to collate transform histories
Returns:
Collated batch dictionary
"""
def get_first_item(data_loader: torch.utils.data.DataLoader):
"""Get first item from data loader for inspection"""
def get_batch_images_and_size(batch: dict) -> tuple[list[str], int]:
"""
Extract image names and batch size from batch.
Returns:
Tuple of (list of image names, batch size)
"""
def get_subjects_from_batch(batch: dict) -> list:
"""Extract individual subjects from batched data"""
def add_images_from_batch(
images_dict: dict,
batch: dict,
batch_idx: int
):
"""Add images from batch to images dictionary"""Functions for efficient memory management when working with large medical image datasets.
def get_torchio_cache_dir() -> Path:
"""Get TorchIO cache directory for temporary files"""
def create_dummy_dataset(
num_subjects: int,
size_range: tuple[int, int] = (10, 20),
directory: Path = None,
**kwargs
) -> SubjectsDataset:
"""
Create dummy dataset for testing and development.
Parameters:
- num_subjects: Number of subjects to create
- size_range: Range of image sizes
- directory: Directory to save dummy images
Returns:
SubjectsDataset with dummy subjects
"""Usage example:
# Create dummy dataset for testing
dummy_dataset = tio.utils.create_dummy_dataset(
num_subjects=10,
size_range=(20, 30),
directory=Path('/tmp/dummy_medical_data')
)
# Use dummy dataset for testing transforms or training loops
dummy_loader = tio.SubjectsLoader(dummy_dataset, batch_size=2)
for batch in dummy_loader:
# Test your code with dummy data
pass

Install with Tessl CLI
npx tessl i tessl/pypi-torchio