# Patch Sampling Strategies

Flexible strategies for extracting patches from 3D medical volumes: uniform random sampling, weighted sampling driven by a probability map, label-focused sampling, and grid-based sampling for inference. Patch-based processing is essential when full volumes are too large to train on, or to run inference on, in a single pass.

## Capabilities

### Base Patch Sampler

Abstract base class for all patch sampling strategies. It defines the interface for extracting patches from a `Subject`.

```python { .api }
class PatchSampler:
    """
    Base class for patch sampling strategies.

    Parameters:
    - patch_size: Size of patches to extract (int or tuple of 3 ints)
    """
    def __init__(self, patch_size: TypeSpatialShape): ...

    def __call__(
        self,
        subject: Subject,
        num_patches: int | None = None,
    ) -> Generator[Subject, None, None]:
        """
        Extract patches from a subject.

        Parameters:
        - subject: Subject to sample from
        - num_patches: Number of patches to yield; random samplers
          yield patches indefinitely if None

        Returns:
        Generator of patches (each one a Subject); every patch stores its
        location in the original volume under the tio.LOCATION key
        """

    @property
    def patch_size(self) -> tuple[int, int, int]:
        """Get patch size as tuple"""
```
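
The following sketch makes the interface concrete. It mimics what extracting a single patch involves, using a plain NumPy array in place of a `Subject`. The helper `extract_patch_sketch` is hypothetical. Its six-value location (start indices followed by exclusive end indices) mirrors the layout stored under `tio.LOCATION`, but the code is an illustration, not TorchIO's implementation.

```python
import numpy as np


def extract_patch_sketch(volume, index_ini, patch_size):
    """Hypothetical helper: crop a (C, W, H, D) array and return patch + location."""
    index_ini = np.asarray(index_ini, dtype=int)
    patch_size = np.asarray(patch_size, dtype=int)
    index_fin = index_ini + patch_size
    spatial_shape = np.asarray(volume.shape[1:])
    # Reject patches that would extend past the volume boundary
    if np.any(index_ini < 0) or np.any(index_fin > spatial_shape):
        raise ValueError("Patch does not fit inside the volume")
    i0, j0, k0 = index_ini
    i1, j1, k1 = index_fin
    patch = volume[:, i0:i1, j0:j1, k0:k1]
    # Location layout: start indices followed by end (exclusive) indices
    location = np.concatenate([index_ini, index_fin])
    return patch, location


volume = np.zeros((1, 128, 128, 128), dtype=np.float32)
patch, location = extract_patch_sketch(volume, (10, 20, 30), (64, 64, 64))
print(patch.shape)        # (1, 64, 64, 64)
print(location.tolist())  # [10, 20, 30, 74, 84, 94]
```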

### Uniform Sampler

Random patch sampling in which every valid patch position is equally likely. A position is valid when the patch fits entirely inside the volume.

```python { .api }
class UniformSampler(PatchSampler):
    """
    Random uniform patch sampling.

    Extracts patches from random locations with equal probability.
    Useful for general training when no specific region needs emphasis.

    Parameters:
    - patch_size: Size of patches to extract
    """
    def __init__(self, patch_size: TypeSpatialShape): ...
```

Usage example:

```python
import torchio as tio

# Create uniform sampler for 64x64x64 patches
sampler = tio.data.UniformSampler(patch_size=64)

# Create subject
subject = tio.Subject(
    t1=tio.ScalarImage('t1.nii.gz'),
    seg=tio.LabelMap('segmentation.nii.gz'),
)

# Calling the sampler returns a generator of patches (each one a Subject);
# without num_patches the generator never stops, so take one with next()
patch = next(sampler(subject))

# Alternatively, draw a fixed number of patches
patches = list(sampler(subject, num_patches=4))

# Access patch data
t1_patch = patch['t1'][tio.DATA]    # Shape: (1, 64, 64, 64)
seg_patch = patch['seg'][tio.DATA]  # Shape: (1, 64, 64, 64)
location = patch[tio.LOCATION]      # Start and end indices in the original image
```
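
"Equal probability for all valid positions" can be made precise. A patch fits only if its start index along each axis lies in `[0, shape - patch_size]`, so drawing each start index uniformly from that range gives every valid position the same probability. This NumPy sketch uses a hypothetical function name and is an illustration, not TorchIO's code.

```python
import numpy as np


def uniform_patch_starts(spatial_shape, patch_size, num_patches, seed=None):
    """Hypothetical helper: draw patch start indices uniformly over valid positions."""
    rng = np.random.default_rng(seed)
    shape = np.asarray(spatial_shape)
    size = np.asarray(patch_size)
    if np.any(size > shape):
        raise ValueError("Patch size is larger than the volume")
    # Largest valid start index per axis; the patch then ends exactly at the border
    max_start = shape - size
    # integers() excludes the upper bound, hence the + 1
    return rng.integers(0, max_start + 1, size=(num_patches, 3))


starts = uniform_patch_starts((100, 120, 80), (64, 64, 64), num_patches=5, seed=0)
print(starts.shape)  # (5, 3)
```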

### Weighted Sampler

Weighted random patch sampling driven by a probability map. Patches are drawn more often from high-probability regions of interest.

```python { .api }
class WeightedSampler(PatchSampler):
    """
    Weighted random patch sampling based on probability maps.

    The probability of sampling a patch centered on a given voxel is
    proportional to that voxel's value in the probability map. Useful
    for focusing sampling on regions of interest.

    Parameters:
    - patch_size: Size of patches to extract
    - probability_map: Name of the image in the subject to use as
      probability map
    """
    def __init__(
        self,
        patch_size: TypeSpatialShape,
        probability_map: str,
    ): ...
```

Usage example:

```python
import torchio as tio

# Load probability map (higher values = higher sampling probability)
probability_map = tio.ScalarImage('probability_map.nii.gz')

subject = tio.Subject(
    t1=tio.ScalarImage('t1.nii.gz'),
    probability_map=probability_map,
)

# Create weighted sampler; probability_map is the map's key in the subject
sampler = tio.data.WeightedSampler(
    patch_size=64,
    probability_map='probability_map',
)

# Extract a patch (more likely centered in a high-probability region)
patch = next(sampler(subject))
```
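
The sketch below outlines weighted center sampling, assuming the map holds non-negative values. Voxels too close to the border to host a full patch get zero weight, the remaining weights are normalized, and one voxel is drawn as the patch center. The function name and the centering convention (`start = center - patch_size // 2`) are assumptions for illustration, and TorchIO's internal details may differ.

```python
import numpy as np


def sample_weighted_start(prob_map, patch_size, seed=None):
    """Hypothetical helper: draw one patch start with probability proportional to prob_map."""
    rng = np.random.default_rng(seed)
    size = np.asarray(patch_size)
    half = size // 2
    weights = np.array(prob_map, dtype=np.float64)  # copy so the input is untouched
    # A center c is valid only if start = c - half >= 0 and start + size <= shape
    lo = half
    hi = np.asarray(weights.shape) - size + half  # inclusive upper bound for centers
    valid = np.zeros_like(weights, dtype=bool)
    valid[lo[0]:hi[0] + 1, lo[1]:hi[1] + 1, lo[2]:hi[2] + 1] = True
    weights[~valid] = 0
    total = weights.sum()
    if total <= 0:
        raise ValueError("No valid voxel has a positive probability")
    # Draw a flat index, then convert it back to 3D coordinates
    flat_index = rng.choice(weights.size, p=weights.ravel() / total)
    center = np.array(np.unravel_index(flat_index, weights.shape))
    return center - half


prob_map = np.zeros((32, 32, 32))
prob_map[20, 10, 15] = 1.0  # all probability mass on a single voxel
start = sample_weighted_start(prob_map, (8, 8, 8), seed=0)
print(start.tolist())  # [16, 6, 11]
```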

### Label Sampler

Patch sampling guided by a label map. Patch centers are drawn from voxels that carry particular label values, with a configurable probability for each label. This is useful for segmentation training, where structures such as tumors may occupy only a small fraction of the volume.

```python { .api }
class LabelSampler(WeightedSampler):
    """
    Patch sampling focused on specific labels.

    Samples patches centered on voxels that carry the specified labels,
    with configurable per-label probabilities. The probability map is
    created automatically from the label image.

    Parameters:
    - patch_size: Size of patches to extract
    - label_name: Name of the label image in the subject; if None, the
      first LabelMap in the subject is used
    - label_probabilities: Dict mapping label values to sampling
      probabilities (they need not sum to 1); if None, the label map is
      binarized and only foreground voxels are used as patch centers
    """
    def __init__(
        self,
        patch_size: TypeSpatialShape,
        label_name: str | None = None,
        label_probabilities: dict[int, float] | None = None,
    ): ...
```

Usage example:

```python
import torchio as tio

# Sampling probabilities for different labels
# (values are relative; they do not need to sum to 1)
label_probs = {
    0: 0.1,  # Background: low probability
    1: 0.8,  # Tumor: high probability
    2: 0.4,  # Edema: medium probability
    3: 0.6,  # Necrosis: medium-high probability
}

subject = tio.Subject(
    t1=tio.ScalarImage('t1.nii.gz'),
    seg=tio.LabelMap('tumor_segmentation.nii.gz'),
)

# Create label sampler
sampler = tio.data.LabelSampler(
    patch_size=64,
    label_name='seg',
    label_probabilities=label_probs,
)

# Extract a patch (more likely to be centered in a tumor region)
patch = next(sampler(subject))
```
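
One way to picture how per-label probabilities become a voxel-wise probability map is to share each label's probability equally among the voxels carrying that label. Under that rule, a small structure is chosen as often as its label probability dictates, regardless of its size. The sketch below assumes this behavior and uses a hypothetical function name; it is not TorchIO's source.

```python
import numpy as np


def label_probability_map(label_map, label_probabilities):
    """Hypothetical helper: turn per-label probabilities into a normalized voxel map."""
    prob_map = np.zeros(label_map.shape, dtype=np.float64)
    for label, probability in label_probabilities.items():
        mask = label_map == label
        count = mask.sum()
        if count == 0:
            continue  # labels absent from this volume contribute nothing
        # Spread the label's probability evenly over its voxels
        prob_map[mask] = probability / count
    total = prob_map.sum()
    if total == 0:
        raise ValueError("None of the requested labels are present")
    return prob_map / total  # normalize so the map sums to 1


label_map = np.zeros((10, 10, 10), dtype=int)
label_map[2:4, 2:4, 2:4] = 1  # small "tumor" (8 voxels)
prob = label_probability_map(label_map, {0: 1, 1: 1})
# Background (992 voxels) and tumor (8 voxels) each receive half of the mass
print(round(float(prob[label_map == 1].sum()), 6))  # 0.5
```

The resulting map could then be fed to a weighted center-sampling step like the one shown for the Weighted Sampler.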

### Grid Sampler

Grid-based patch sampling that tiles the entire volume systematically, so that every voxel falls inside at least one patch. It is mainly used at inference time, when a prediction is needed for the whole volume.

```python { .api }
class GridSampler(PatchSampler, torch.utils.data.Dataset):
    """
    Regular grid-based patch sampling for inference.

    Extracts patches in a regular grid pattern to ensure complete
    volume coverage. Typically used for inference rather than training.

    Parameters:
    - subject: Subject to extract patches from; needed before creating
      a GridAggregator
    - patch_size: Size of patches to extract
    - patch_overlap: Overlap between adjacent patches (int or tuple)
    """
    def __init__(
        self,
        subject: Subject,
        patch_size: TypeSpatialShape,
        patch_overlap: TypeSpatialShape = 0,
    ): ...

    def __len__(self) -> int:
        """Total number of patches in grid"""

    def __getitem__(self, index: int) -> Subject:
        """Patch at the given position in the grid"""
```

Usage example:

```python
import torchio as tio

subject = tio.Subject(
    t1=tio.ScalarImage('t1.nii.gz'),
)

# Create grid sampler with 50% overlap (32 of 64 voxels per axis)
sampler = tio.data.GridSampler(
    subject,
    patch_size=64,
    patch_overlap=32,
)

# The sampler is a PyTorch Dataset, so len() gives the number of patches
print(f"Total patches: {len(sampler)}")

# Extract all patches systematically
all_patches = [sampler[i] for i in range(len(sampler))]
```
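
The sketch below shows why overlap increases the patch count by laying out start indices along each axis. Consecutive patches advance by `patch_size - patch_overlap`, and a final patch is aligned with the far border if the stride does not land there exactly. This is one common layout, assumed for illustration with hypothetical function names. The library's own grid may differ; for example, it can pad the volume first.

```python
import itertools

import numpy as np


def grid_starts_1d(length, patch, overlap):
    """Hypothetical helper: start indices so patches of size `patch` cover [0, length)."""
    if patch > length:
        raise ValueError("Patch is larger than the axis")
    stride = patch - overlap
    if stride <= 0:
        raise ValueError("Overlap must be smaller than the patch size")
    starts = list(range(0, length - patch + 1, stride))
    # Add a final patch flush with the border if the stride skipped the end
    if starts[-1] + patch < length:
        starts.append(length - patch)
    return starts


def grid_locations(spatial_shape, patch_size, patch_overlap):
    """Hypothetical helper: all 3D locations as [i0, j0, k0, i1, j1, k1] rows."""
    per_axis = [
        grid_starts_1d(n, p, o)
        for n, p, o in zip(spatial_shape, patch_size, patch_overlap)
    ]
    rows = []
    for start in itertools.product(*per_axis):
        start = np.array(start)
        rows.append(np.concatenate([start, start + np.array(patch_size)]))
    return np.array(rows)


print(grid_starts_1d(100, 64, 32))  # [0, 32, 36]
locations = grid_locations((100, 100, 100), (64, 64, 64), (32, 32, 32))
print(len(locations))  # 27
```

Aligning the last patch with the border keeps every patch the same size, at the cost of a larger overlap for that final patch.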

### Grid Aggregator

Stitches predictions made on grid-sampled patches back into a full-volume prediction, combining the values where adjacent patches overlap.

```python { .api }
class GridAggregator:
    """
    Aggregates predictions from grid-sampled patches.

    Reconstructs full-volume predictions from overlapping patches,
    using one of several strategies for the overlapping regions.

    Parameters:
    - sampler: GridSampler used to extract patches (created with a subject)
    - overlap_mode: How to handle overlapping regions:
      'crop' (discard the overlapping borders of each patch),
      'average' (average overlapping predictions) or
      'hann' (weighted average using a Hann window)
    """
    def __init__(
        self,
        sampler: GridSampler,
        overlap_mode: str = 'crop',
    ): ...

    def add_batch(
        self,
        batch_tensor: torch.Tensor,
        batch_locations: torch.Tensor,
    ):
        """Add a batch of predictions (B, C, W, H, D) with their locations (B, 6)"""

    def get_output_tensor(self) -> torch.Tensor:
        """Get aggregated full-volume prediction"""
```
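
The `'hann'` mode can be understood as a weighted average in which voxels near a patch's center count more than voxels near its edges, which tends to soften seams between patches. A minimal NumPy sketch of that idea follows. The function names are hypothetical, and the window construction is an assumption that may not match TorchIO's exact weights. The window is a 3D outer product of 1D Hann windows, trimmed so that no weight is exactly zero.

```python
import numpy as np


def hann_window_3d(patch_size):
    """Hypothetical helper: 3D weights from 1D Hann windows, endpoints dropped."""
    w = [np.hanning(n + 2)[1:-1] for n in patch_size]  # strictly positive weights
    return w[0][:, None, None] * w[1][None, :, None] * w[2][None, None, :]


def aggregate_hann(patches, locations, spatial_shape):
    """Hypothetical helper: Hann-weighted average of overlapping (C, w, h, d) patches."""
    channels = patches[0].shape[0]
    weighted_sum = np.zeros((channels, *spatial_shape))
    weight_total = np.zeros(spatial_shape)
    for patch, (i0, j0, k0, i1, j1, k1) in zip(patches, locations):
        window = hann_window_3d(patch.shape[1:])
        weighted_sum[:, i0:i1, j0:j1, k0:k1] += patch * window
        weight_total[i0:i1, j0:j1, k0:k1] += window
    # Assumes every voxel is covered by at least one patch, so weight_total > 0
    return weighted_sum / weight_total


patches = [np.full((1, 4, 4, 4), 2.0), np.full((1, 4, 4, 4), 2.0)]
locations = [(0, 0, 0, 4, 4, 4), (2, 0, 0, 6, 4, 4)]
output = aggregate_hann(patches, locations, (6, 4, 4))
print(np.allclose(output, 2.0))  # True
```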

Usage example:

```python
import torch
import torchio as tio

# Setup for inference
subject = tio.Subject(t1=tio.ScalarImage('t1.nii.gz'))
patch_size = 64
patch_overlap = 16

# Create grid sampler and aggregator; the sampler must be built with the
# subject so that the aggregator knows the shape of the full volume
grid_sampler = tio.data.GridSampler(
    subject,
    patch_size=patch_size,
    patch_overlap=patch_overlap,
)
aggregator = tio.data.GridAggregator(
    sampler=grid_sampler,
    overlap_mode='average',
)

# Model inference on patches
model = load_your_model()  # placeholder: replace with your own model loading
model.eval()

with torch.no_grad():
    for i in range(len(grid_sampler)):
        patch = grid_sampler[i]
        # Get patch data and location, adding a batch dimension to each
        patch_tensor = patch['t1'][tio.DATA].unsqueeze(0)  # (1, C, 64, 64, 64)
        location = patch[tio.LOCATION].unsqueeze(0)       # (1, 6)

        # Run inference
        prediction = model(patch_tensor)

        # Add to aggregator
        aggregator.add_batch(prediction, location)

# Get full-volume prediction
full_prediction = aggregator.get_output_tensor()
print(f"Full prediction shape: {full_prediction.shape}")
```
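
Under `overlap_mode='average'`, each output voxel is the mean of all patch predictions that cover it. The sketch below reproduces that idea with two NumPy accumulators, a running sum and a per-voxel count. It is a simplified illustration with a hypothetical class name, not the library's implementation.

```python
import numpy as np


class AverageAggregatorSketch:
    """Hypothetical class: accumulate overlapping patch predictions and average them."""

    def __init__(self, spatial_shape, num_channels=1):
        self.sums = np.zeros((num_channels, *spatial_shape))
        self.counts = np.zeros(spatial_shape)

    def add_batch(self, batch, locations):
        # batch: (B, C, w, h, d); locations: (B, 6) rows of [i0, j0, k0, i1, j1, k1]
        for patch, (i0, j0, k0, i1, j1, k1) in zip(batch, locations):
            self.sums[:, i0:i1, j0:j1, k0:k1] += patch
            self.counts[i0:i1, j0:j1, k0:k1] += 1

    def get_output(self):
        if np.any(self.counts == 0):
            raise RuntimeError("Some voxels were never covered by a patch")
        return self.sums / self.counts


agg = AverageAggregatorSketch((6, 4, 4))
batch = np.stack([np.zeros((1, 4, 4, 4)), np.ones((1, 4, 4, 4))])
agg.add_batch(batch, [(0, 0, 0, 4, 4, 4), (2, 0, 0, 6, 4, 4)])
print(agg.get_output()[0, :, 0, 0].tolist())  # [0.0, 0.0, 0.5, 0.5, 1.0, 1.0]
```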

### Sampling Utilities

Helper functions for inspecting batches of patches and parsing spatial shape specifications.

```python { .api }
def get_batch_images_and_size(batch: dict) -> tuple[list[str], int]:
    """Extract image names and batch size from patch batch"""

def parse_spatial_shape(shape) -> tuple[int, int, int]:
    """Parse spatial shape specification into 3D tuple"""
```
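
The `patch_size` arguments above accept either a single int or a tuple of three ints. The following is a hypothetical re-implementation of that normalization step. It assumes an int is repeated along all three axes (so `patch_size=64` means 64x64x64, as in the examples above) and that all sizes must be positive; it is not the library's `parse_spatial_shape`.

```python
def to_spatial_triplet(shape):
    """Hypothetical helper: normalize an int or 3-sequence of ints into a 3-tuple."""
    if isinstance(shape, int):
        values = (shape, shape, shape)  # a single int applies to every axis
    else:
        values = tuple(shape)
        if len(values) != 3:
            raise ValueError(f"Expected 3 values, got {len(values)}")
    if not all(isinstance(v, int) and v > 0 for v in values):
        raise ValueError(f"All sizes must be positive integers, got {values}")
    return values


print(to_spatial_triplet(64))            # (64, 64, 64)
print(to_spatial_triplet([64, 64, 32]))  # (64, 64, 32)
```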