
tessl/pypi-torchio

Tools for medical image processing with PyTorch


Patch Sampling Strategies

Sampling strategies for extracting patches from 3D medical volumes: uniform sampling, weighted sampling based on probability maps, label-focused sampling, and grid-based sampling for inference. Patch-based training and inference rely on them when full volumes are too large to process at once.

Capabilities

Base Patch Sampler

Abstract base class for all patch sampling strategies, defining the interface for patch extraction from medical image subjects.

class PatchSampler:
    """
    Base class for patch sampling strategies.
    
    Parameters:
    - patch_size: Size of patches to extract (int or tuple of 3 ints)
    """
    def __init__(self, patch_size: TypeSpatialShape): ...
    
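    def __call__(
        self,
        subject: Subject,
        num_patches: Optional[int] = None
    ) -> Generator[Subject, None, None]:
        """
        Generate patches from a subject.
        
        Parameters:
        - subject: Subject to sample from
        - num_patches: Number of patches to yield; if None, random
          samplers yield patches indefinitely
        
        Returns:
        Generator of Subject patches, each storing its position in the
        original volume under tio.LOCATION
        """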
    
    @property
    def patch_size(self) -> tuple[int, int, int]:
        """Get patch size as tuple"""
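To make the location information concrete, the following pure-Python sketch builds a 6-value location in the (i_ini, j_ini, k_ini, i_fin, j_fin, k_fin) order that torchio stores under tio.LOCATION, then uses it to crop a nested-list volume. The helpers `make_location` and `crop` are hypothetical stand-ins, not torchio functions; torchio crops the tensors of a `Subject` instead.

```python
from typing import List, Sequence, Tuple

# A toy volume indexed as volume[i][j][k]
Volume = List[List[List[int]]]


def make_location(index_ini: Sequence[int], patch_size: Sequence[int]) -> Tuple[int, ...]:
    """Start indices followed by exclusive end indices: (i0, j0, k0, i1, j1, k1)."""
    index_fin = [i + p for i, p in zip(index_ini, patch_size)]
    return tuple(index_ini) + tuple(index_fin)


def crop(volume: Volume, location: Sequence[int]) -> Volume:
    """Return the sub-volume described by a 6-value location (hypothetical helper)."""
    i0, j0, k0, i1, j1, k1 = location
    return [[row[k0:k1] for row in plane[j0:j1]] for plane in volume[i0:i1]]


# Usage: a 4x4x4 volume whose voxel value encodes its own index as 100*i + 10*j + k
volume = [[[100 * i + 10 * j + k for k in range(4)] for j in range(4)] for i in range(4)]
location = make_location((1, 2, 0), (2, 2, 3))
patch = crop(volume, location)
print(location)        # (1, 2, 0, 3, 4, 3)
print(patch[0][0][0])  # 120
```

Because the end indices are exclusive, `i_fin - i_ini` always equals the patch size along that axis.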

Uniform Sampler

Random uniform patch sampling: every valid patch position (one where the patch lies entirely inside the volume) is chosen with equal probability.

class UniformSampler(PatchSampler):
    """
    Random uniform patch sampling.
    
    Extracts patches from random locations with equal probability.
    Useful for general training when no specific region needs emphasis.
    
    Parameters:
    - patch_size: Size of patches to extract
    """
    def __init__(self, patch_size: TypeSpatialShape): ...

Usage example:

import torchio as tio

# Create uniform sampler for 64x64x64 patches
sampler = tio.data.UniformSampler(patch_size=64)

# Create subject
subject = tio.Subject(
    t1=tio.ScalarImage('t1.nii.gz'),
    seg=tio.LabelMap('segmentation.nii.gz')
)

# Draw one patch (calling the sampler returns a generator that yields patches indefinitely)
patch = next(sampler(subject))

# Access patch data
t1_patch = patch['t1'][tio.DATA]  # Shape: (1, 64, 64, 64)
seg_patch = patch['seg'][tio.DATA]  # Shape: (1, 64, 64, 64)
location = patch[tio.LOCATION]  # (i_ini, j_ini, k_ini, i_fin, j_fin, k_fin) in the original image
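Conceptually, uniform sampling only needs to draw each axis's start index uniformly from the positions where the patch still fits. The sketch below illustrates that idea with the standard library; `uniform_locations` is a hypothetical helper, not torchio's implementation, although like the torchio sampler it yields indefinitely when `num_patches` is None.

```python
import random
from typing import Iterator, Optional, Tuple


def uniform_locations(
    volume_shape: Tuple[int, int, int],
    patch_size: Tuple[int, int, int],
    num_patches: Optional[int] = None,
    seed: Optional[int] = None,
) -> Iterator[Tuple[int, ...]]:
    """Yield (i_ini, j_ini, k_ini, i_fin, j_fin, k_fin) for uniformly drawn patches."""
    rng = random.Random(seed)
    # Highest valid start index per axis, so the patch stays inside the volume
    max_start = [s - p for s, p in zip(volume_shape, patch_size)]
    if any(m < 0 for m in max_start):
        raise ValueError("Patch does not fit in the volume")
    count = 0
    while num_patches is None or count < num_patches:
        ini = tuple(rng.randint(0, m) for m in max_start)  # randint includes both ends
        fin = tuple(i + p for i, p in zip(ini, patch_size))
        yield ini + fin
        count += 1


# Usage: five random 64x64x32 patches from a 100x80x60 volume
for loc in uniform_locations((100, 80, 60), (64, 64, 32), num_patches=5, seed=0):
    print(loc)
```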

Weighted Sampler

Weighted random patch sampling driven by a probability map: the probability of centering a patch on a voxel is proportional to that voxel's value in the map, so regions of interest are sampled more often.

class WeightedSampler(PatchSampler):
    """
    Weighted random patch sampling based on probability maps.
    
    Samples patch centers with probability proportional to the voxel
    values of the probability map (values need not be normalized).
    Useful for focusing sampling on regions of interest.
    
    Parameters:
    - patch_size: Size of patches to extract
    - probability_map: Name of image to use as probability map
    """
    def __init__(
        self,
        patch_size: TypeSpatialShape,
        probability_map: str
    ): ...

Usage example:

# Create probability map (higher values = higher sampling probability)
probability_map = tio.ScalarImage('probability_map.nii.gz')

subject = tio.Subject(
    t1=tio.ScalarImage('t1.nii.gz'),
    probability_map=probability_map
)

# Create weighted sampler
sampler = tio.data.WeightedSampler(
    patch_size=64,
    probability_map='probability_map'
)

# Extract one patch (centers are drawn more often from high-probability regions)
patch = next(sampler(subject))
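The weighted draw can be pictured as inverse-CDF sampling over voxels. This sketch is an illustration under stated assumptions, not torchio's code: `make_weighted_sampler` is hypothetical, it works on nested lists, and it gives zero weight to centers whose patch would leave the volume, taking center = start + patch_size // 2.

```python
import bisect
import itertools
import random
from typing import Callable, List, Sequence, Tuple


def make_weighted_sampler(
    prob_map: List[List[List[float]]],
    patch_size: Sequence[int],
) -> Callable[[random.Random], Tuple[int, ...]]:
    """Precompute a cumulative distribution over valid centers; return a draw function."""
    shape = (len(prob_map), len(prob_map[0]), len(prob_map[0][0]))
    half = [p // 2 for p in patch_size]
    centers: List[Tuple[int, int, int]] = []
    cumulative: List[float] = []
    total = 0.0
    for c in itertools.product(*(range(s) for s in shape)):
        ini = [ci - h for ci, h in zip(c, half)]
        # Skip centers whose patch would leave the volume
        if any(i < 0 or i + p > s for i, p, s in zip(ini, patch_size, shape)):
            continue
        weight = prob_map[c[0]][c[1]][c[2]]
        if weight <= 0:
            continue
        total += weight
        centers.append(c)
        cumulative.append(total)
    if not centers:
        raise ValueError("No valid center has positive probability")

    def draw(rng: random.Random) -> Tuple[int, ...]:
        # Pick the first center whose cumulative weight exceeds u
        u = rng.random() * total
        c = centers[bisect.bisect_right(cumulative, u)]
        ini = tuple(ci - h for ci, h in zip(c, half))
        return ini + tuple(i + p for i, p in zip(ini, patch_size))

    return draw


# Usage: two voxels with weights 5 and 15, so the first is chosen about 25% of the time
pm = [[[0.0] * 6 for _ in range(6)] for _ in range(6)]
pm[3][3][3], pm[2][2][2] = 5.0, 15.0
draw = make_weighted_sampler(pm, (4, 4, 4))
print(draw(random.Random(0)))
```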

Label Sampler

Label-focused patch sampling: patch centers are drawn from labeled voxels, with configurable probabilities for each label value. This is useful for segmentation training, where the structures of interest may occupy only a small fraction of the volume.

class LabelSampler(WeightedSampler):
    """
    Patch sampling focused on specific labels.
    
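    Samples patch centers so that the center voxel carries each label
    with the configured probability. The probability map is built
    automatically from the label image; if label_probabilities is None,
    centers are drawn from voxels whose label is greater than 0.
    
    Parameters:
    - patch_size: Size of patches to extract
    - label_name: Name of label image in subject
    - label_probabilities: Dict mapping label values to relative sampling
      probabilities (they need not sum to 1)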
    """
    def __init__(
        self,
        patch_size: TypeSpatialShape,
        label_name: str,
        label_probabilities: Optional[dict] = None
    ): ...

Usage example:

# Relative sampling weights per label (normalized internally, so they need not sum to 1)
label_probs = {
    0: 0.1,   # Background: low probability
    1: 0.8,   # Tumor: high probability
    2: 0.4,   # Edema: medium probability
    3: 0.6,   # Necrosis: medium-high probability
}

subject = tio.Subject(
    t1=tio.ScalarImage('t1.nii.gz'),
    seg=tio.LabelMap('tumor_segmentation.nii.gz')
)

# Create label sampler
sampler = tio.data.LabelSampler(
    patch_size=64,
    label_name='seg',
    label_probabilities=label_probs
)

# Extract one patch (its center is most likely to lie in the tumor, label 1)
patch = next(sampler(subject))
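One way to turn per-label probabilities into a voxel-wise probability map is to spread each label's weight evenly over the voxels carrying it, so that a label's total mass follows its requested probability regardless of how many voxels it covers. The sketch below follows that idea; `label_probability_map` is a hypothetical helper, and dropping labels absent from the volume before normalizing is a choice made here, not documented torchio behavior.

```python
from collections import Counter
from typing import Dict, List


def label_probability_map(
    label_map: List[List[List[int]]],
    label_probabilities: Dict[int, float],
) -> List[List[List[float]]]:
    """Spread each label's weight evenly over its voxels, normalized to sum to 1."""
    counts = Counter(v for plane in label_map for row in plane for v in row)
    # Only labels actually present in the volume share the probability mass
    total_weight = sum(w for label, w in label_probabilities.items() if counts[label] > 0)
    if total_weight <= 0:
        raise ValueError("No voxel carries a label with positive probability")
    return [
        [
            [label_probabilities.get(v, 0.0) / counts[v] / total_weight for v in row]
            for row in plane
        ]
        for plane in label_map
    ]


# Usage: label 3 is absent, so the weights of labels 0, 1 and 2 are renormalized
label_map = [[[0, 0, 0, 0, 1], [1, 2, 2, 0, 0]]]
pmap = label_probability_map(label_map, {0: 0.1, 1: 0.8, 2: 0.4, 3: 0.6})
print(pmap)
```

The resulting map could then feed a weighted draw like the one sketched for WeightedSampler.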

Grid Sampler

Grid-based patch sampling that tiles the entire volume systematically. It is mainly used for inference, where every voxel must be covered by at least one patch.

class GridSampler(PatchSampler):
    """
    Regular grid-based patch sampling for inference.
    
    Extracts patches in a regular grid pattern to ensure complete
    volume coverage. Typically used for inference rather than training.
    
    Parameters:
    - subject: Subject to sample; required when the sampler is passed to a GridAggregator
    - patch_size: Size of patches to extract
    - patch_overlap: Overlap between adjacent patches (int or tuple of even ints)
    """
    def __init__(
        self,
        subject: Optional[Subject] = None,
        patch_size: TypeSpatialShape = None,
        patch_overlap: TypeSpatialShape = 0
    ): ...
    
    def __len__(self) -> int:
        """Total number of patches in grid"""
    
    def __getitem__(self, index: int) -> Subject:
        """Patch at the given grid position"""

Usage example:

subject = tio.Subject(
    t1=tio.ScalarImage('t1.nii.gz')
)

# Create grid sampler bound to the subject, with 50% overlap (overlap values must be even)
sampler = tio.data.GridSampler(
    subject,
    patch_size=64,
    patch_overlap=32,  # 50% overlap
)

# Extract all patches systematically
all_patches = [sampler[i] for i in range(len(sampler))]

print(f"Total patches: {len(all_patches)}")
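Grid start positions can be derived axis by axis: step by `patch_size - patch_overlap`, then add one final start flush with the far border so the last voxels are covered. The sketch below shows this common scheme with hypothetical helpers (`axis_starts`, `grid_locations`); torchio's internal computation may differ in details such as padding and ordering.

```python
import itertools
from typing import List, Sequence, Tuple


def axis_starts(size: int, patch: int, overlap: int) -> List[int]:
    """Start indices along one axis, with the last patch flush against the end."""
    if patch > size:
        raise ValueError("Patch is larger than the volume")
    step = patch - overlap
    if step <= 0:
        raise ValueError("Overlap must be smaller than the patch size")
    starts = list(range(0, size - patch + 1, step))
    if starts[-1] != size - patch:
        # One extra, more heavily overlapping patch reaches the border
        starts.append(size - patch)
    return starts


def grid_locations(
    shape: Sequence[int], patch_size: Sequence[int], overlap: Sequence[int]
) -> List[Tuple[int, ...]]:
    """All (i_ini, j_ini, k_ini, i_fin, j_fin, k_fin) locations of the grid."""
    per_axis = [axis_starts(s, p, o) for s, p, o in zip(shape, patch_size, overlap)]
    return [
        tuple(ini) + tuple(i + p for i, p in zip(ini, patch_size))
        for ini in itertools.product(*per_axis)
    ]


# Usage: 64^3 patches with 32 voxels of overlap on a 128x128x100 volume
print(axis_starts(100, 64, 32))  # [0, 32, 36]
print(len(grid_locations((128, 128, 100), (64, 64, 64), (32, 32, 32))))  # 27
```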

Grid Aggregator

Aggregates predictions from grid-sampled patches back into a full-volume prediction, combining overlapping regions according to the chosen overlap mode.

class GridAggregator:
    """
    Aggregates predictions from grid-sampled patches.
    
    Reconstructs full-volume predictions from overlapping patches,
    handling various aggregation strategies for overlapping regions.
    
    Parameters:
    - sampler: GridSampler used to extract patches
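    - overlap_mode: How to handle overlapping regions: 'crop' keeps only the
      central part of each patch, 'average' averages overlapping predictions
      with equal weights, 'hann' weights them with a Hann window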
    """
    def __init__(
        self,
        sampler: GridSampler,
        overlap_mode: str = 'crop'
    ): ...
    
    def add_batch(
        self,
        batch_tensor: torch.Tensor,
        batch_locations: torch.Tensor
    ):
        """Add a batch of predictions (B, C, W, H, D) with their locations (B, 6)"""
    
    def get_output_tensor(self) -> torch.Tensor:
        """Get aggregated full-volume prediction"""

Usage example:

import torch
import torchio as tio

# Setup for inference
subject = tio.Subject(t1=tio.ScalarImage('t1.nii.gz'))
patch_size = 64
patch_overlap = 16

# Create grid sampler (bound to the subject, which the aggregator needs) and aggregator
grid_sampler = tio.data.GridSampler(
    subject,
    patch_size=patch_size,
    patch_overlap=patch_overlap
)
aggregator = tio.data.GridAggregator(
    grid_sampler,
    overlap_mode='average'
)

# Model inference on patches
model = load_your_model()  # placeholder for your own model-loading code
model.eval()

with torch.no_grad():
    for i in range(len(grid_sampler)):
        # Get patch data and location
        patch = grid_sampler[i]
        patch_tensor = patch['t1'][tio.DATA].unsqueeze(0)  # Add batch dim
        location = patch[tio.LOCATION]
        
        # Run inference
        prediction = model(patch_tensor)
        
        # Add to aggregator (locations also need a batch dim)
        aggregator.add_batch(prediction, location.unsqueeze(0))

# Get full-volume prediction
full_prediction = aggregator.get_output_tensor()
print(f"Full prediction shape: {full_prediction.shape}")
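The 'average' overlap mode amounts to keeping a running sum and a coverage count per voxel and dividing at the end. The sketch below shows that accumulation along a single axis for brevity; `average_aggregate` is a hypothetical helper, not the GridAggregator implementation, which works on batched tensors and also supports the other overlap modes.

```python
from typing import List, Sequence, Tuple


def average_aggregate(
    length: int,
    patches: Sequence[Tuple[int, Sequence[float]]],
) -> List[float]:
    """Average overlapping 1D patch predictions; each item is (start, values)."""
    sums = [0.0] * length
    counts = [0] * length
    for start, values in patches:
        for offset, value in enumerate(values):
            sums[start + offset] += value
            counts[start + offset] += 1
    if any(c == 0 for c in counts):
        raise ValueError("Some positions are not covered by any patch")
    return [s / c for s, c in zip(sums, counts)]


# Usage: two 4-voxel patches overlapping on positions 2 and 3
print(average_aggregate(6, [(0, [1, 1, 1, 1]), (2, [3, 3, 3, 3])]))
# [1.0, 1.0, 2.0, 2.0, 3.0, 3.0]
```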

Sampling Utilities

Utility functions for patch sampling operations and analysis.

def get_batch_images_and_size(batch: dict) -> tuple[list[str], int]:
    """Extract image names and batch size from patch batch"""

def parse_spatial_shape(shape) -> tuple[int, int, int]:
    """Parse spatial shape specification into 3D tuple"""
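For orientation, here is what the two helpers are described as doing, written as standalone sketches. The bodies are assumptions inferred from the one-line descriptions above (for example, that batch entries holding images are dicts with a 'data' key, the value of tio.DATA), and the `_sketch` suffix marks them as hypothetical rather than torchio's code.

```python
from typing import Any, Dict, List, Sequence, Tuple, Union

DATA = "data"  # assumed key under which image tensors are stored in a batch


def parse_spatial_shape_sketch(shape: Union[int, Sequence[int]]) -> Tuple[int, int, int]:
    """Expand an int to a cube, or validate a 3-element sequence."""
    if isinstance(shape, int):
        return (shape, shape, shape)
    values = tuple(int(v) for v in shape)
    if len(values) != 3:
        raise ValueError(f"Expected an int or 3 values, got {shape!r}")
    return values  # type: ignore[return-value]


def get_batch_images_and_size_sketch(batch: Dict[str, Any]) -> Tuple[List[str], int]:
    """Return the names of image entries and the batch size (length of their data)."""
    names: List[str] = []
    size = 0
    for name, value in batch.items():
        # Image entries are dicts holding data; other keys (e.g. locations) are skipped
        if isinstance(value, dict) and DATA in value:
            names.append(name)
            size = len(value[DATA])
    if not names:
        raise RuntimeError("The batch does not contain any images")
    return names, size


# Usage with a batch of four patches, using nested lists in place of tensors
batch = {"t1": {DATA: [[0.0]] * 4}, "seg": {DATA: [[1]] * 4}, "location": [[0] * 6] * 4}
print(parse_spatial_shape_sketch(64))            # (64, 64, 64)
print(get_batch_images_and_size_sketch(batch))   # (['t1', 'seg'], 4)
```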

Install with Tessl CLI

npx tessl i tessl/pypi-torchio
