# TorchIO

TorchIO is a Python package for reading, preprocessing, sampling, augmenting, and writing 3D medical images in deep learning applications built with PyTorch. Alongside standard computer vision transforms, it provides domain-specific transforms that simulate realistic artifacts such as MRI magnetic field inhomogeneity and k-space motion artifacts.

## Package Information

- **Package Name**: torchio
- **Language**: Python
- **Installation**: `pip install torchio`

## Core Imports

```python
import torchio as tio
```

For specific components:

```python
from torchio import Subject, ScalarImage, LabelMap
from torchio import SubjectsDataset, SubjectsLoader
from torchio import Compose, RandomFlip, RandomAffine
from torchio.data import UniformSampler, GridSampler
```

## Basic Usage

```python
import torchio as tio

# Create a subject with medical images and arbitrary metadata
subject = tio.Subject(
    t1=tio.ScalarImage('t1.nii.gz'),
    t2=tio.ScalarImage('t2.nii.gz'),
    seg=tio.LabelMap('seg.nii.gz'),
    age=45,
    name='Subject_001',
)

# Define preprocessing and augmentation transforms
transform = tio.Compose([
    tio.ToCanonical(),              # Reorient to canonical (RAS+) orientation
    tio.Resample(1),                # Resample to 1 mm isotropic spacing
    tio.CropOrPad((128, 128, 64)),  # Crop or pad to the target shape
    tio.ZNormalization(),           # Z-score normalization
    tio.RandomFlip(),               # Random flipping
    tio.RandomAffine(),             # Random affine transformation
    tio.RandomNoise(std=0.1),       # Add random noise
])

# Apply transforms to subject
transformed_subject = transform(subject)

# Create dataset for training
subjects = [subject]  # In practice, a list with one Subject per case
dataset = tio.SubjectsDataset(subjects, transform=transform)

# Create a patch queue and loader for patch-based training
patch_size = 64
samples_per_volume = 10
sampler = tio.data.UniformSampler(patch_size)
patches_queue = tio.Queue(
    subjects_dataset=dataset,
    max_length=100,
    samples_per_volume=samples_per_volume,
    sampler=sampler,
)
# num_workers stays at 0 here; the queue loads the subjects itself
loader = tio.SubjectsLoader(patches_queue, batch_size=4, num_workers=0)

# Training loop
for batch in loader:
    # Each batch contains patches ready for training
    inputs = batch['t1'][tio.DATA]
    targets = batch['seg'][tio.DATA]
    # ... train your model
```

## Architecture

TorchIO follows a hierarchical design built around medical image processing workflows:

- **Subject**: Dictionary-like container storing multiple medical images and metadata for a single patient/case
- **Image Types**: Specialized classes for different kinds of images (`ScalarImage` for intensity images, `LabelMap` for segmentations)
- **Transform System**: Hierarchical transform classes (`Transform` → `SpatialTransform`/`IntensityTransform`) with history tracking
- **Sampling Strategies**: Flexible patch sampling for 3D volumes (uniform, weighted, label-based, grid-based)
- **Data Pipeline**: PyTorch-compatible datasets and loaders for medical imaging workflows

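The idea of a transform hierarchy with history tracking can be pictured with a small, library-free sketch. The names below (`SketchSubject`, `SketchTransform`, `AddOffset`, `Scale`) are hypothetical stand-ins rather than TorchIO classes; they only illustrate how each transform can record itself on the object it returns so that the sequence of applied operations can be inspected later.

```python
from dataclasses import dataclass, field


@dataclass
class SketchSubject:
    """Hypothetical stand-in for a subject: a few values plus a history."""
    values: list
    history: list = field(default_factory=list)


class SketchTransform:
    """Hypothetical base class: applies an operation and records it."""

    def __call__(self, subject):
        # Return a new object so the input subject is left untouched
        return SketchSubject(
            values=self.apply(subject.values),
            history=[*subject.history, repr(self)],
        )

    def apply(self, values):
        raise NotImplementedError


class AddOffset(SketchTransform):
    def __init__(self, offset):
        self.offset = offset

    def apply(self, values):
        return [v + self.offset for v in values]

    def __repr__(self):
        return f'AddOffset({self.offset})'


class Scale(SketchTransform):
    def __init__(self, factor):
        self.factor = factor

    def apply(self, values):
        return [v * self.factor for v in values]

    def __repr__(self):
        return f'Scale({self.factor})'


original = SketchSubject(values=[1.0, 2.0])
transformed = Scale(2.0)(AddOffset(1.0)(original))
print(transformed.values)
print(transformed.history)
```

Keeping the history on the returned object rather than on the transform means the same transform instance can be reused across subjects without mixing their records.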
## Capabilities

### Core Data Structures

Essential data structures for handling medical images: the `Subject` container, the `Image` types, and `SubjectsDataset` for organizing multiple subjects.

```python { .api }
class Subject(dict):
    def __init__(self, *args, **kwargs: dict[str, Any]): ...
    def get_images(self, intensity_only: bool = True) -> list[Image]: ...
    def check_consistent_spatial_shape(self): ...

class Image:
    def __init__(self, path: Optional[TypePath] = None, type: str = None, **kwargs): ...
    @property
    def data(self) -> torch.Tensor: ...
    @property
    def affine(self) -> np.ndarray: ...

class ScalarImage(Image):
    """Represents intensity/scalar medical images (e.g., MRI, CT scans)"""

class LabelMap(Image):
    """Represents segmentation/label images"""

class SubjectsDataset(torch.utils.data.Dataset):
    def __init__(self, subjects: Sequence[Subject], transform: Transform = None): ...
```

[Core Data Structures](./core-data-structures.md)

### Data Loading and Management

PyTorch-compatible data loading utilities: a loader for batches of subjects and a queue that buffers patches for patch-based training.

```python { .api }
class SubjectsLoader(torch.utils.data.DataLoader):
    """PyTorch DataLoader subclass for batches of subjects"""

class Queue(torch.utils.data.Dataset):
    def __init__(
        self,
        subjects_dataset: SubjectsDataset,
        max_length: int,
        samples_per_volume: int,
        sampler: PatchSampler,
        **kwargs
    ): ...
```

[Data Loading and Management](./data-loading.md)

### Patch Sampling Strategies

Sampling strategies for extracting patches from 3D medical volumes: uniform sampling, weighted sampling based on probability maps, label-focused sampling, and grid-based sampling for inference.

```python { .api }
class PatchSampler:
    """Base class for patch sampling strategies"""
    def __call__(self, subject: Subject, num_patches: int = None) -> Generator[Subject, None, None]: ...

class UniformSampler(PatchSampler):
    def __init__(self, patch_size: TypeSpatialShape): ...

class WeightedSampler(PatchSampler):
    def __init__(self, patch_size: TypeSpatialShape, probability_map: str): ...

class LabelSampler(WeightedSampler):
    def __init__(self, patch_size: TypeSpatialShape, label_name: str, label_probabilities: dict = None): ...

class GridSampler(PatchSampler):
    def __init__(self, subject: Subject, patch_size: TypeSpatialShape, patch_overlap: TypeSpatialShape = 0): ...

class GridAggregator:
    def __init__(self, sampler: GridSampler, overlap_mode: str = 'crop'): ...
```

[Patch Sampling Strategies](./sampling.md)

### Preprocessing Transforms

Preprocessing transforms for medical images, including spatial transforms (resampling, cropping, padding) and intensity normalization (z-score, rescaling, histogram standardization).

```python { .api }
# Spatial preprocessing
class Resample(SpatialTransform):
    def __init__(self, target: TypeSpacing, image_interpolation: str = 'linear'): ...

class CropOrPad(SpatialTransform):
    def __init__(self, target_shape: TypeSpatialShape, padding_mode: Union[str, float] = 0): ...

class ToCanonical(SpatialTransform):
    """Reorients images to canonical orientation (RAS+)"""

# Intensity preprocessing
class ZNormalization(IntensityTransform):
    def __init__(self, masking_method: str = None): ...

class RescaleIntensity(IntensityTransform):
    def __init__(self, out_min_max: tuple[float, float] = (0, 1)): ...

class HistogramStandardization(IntensityTransform):
    def __init__(self, landmarks: dict): ...
```

[Preprocessing Transforms](./preprocessing.md)

### Augmentation Transforms

Augmentation transforms, including spatial augmentations (flipping, affine, elastic deformation) and intensity augmentations that simulate medical imaging artifacts (motion, ghosting, bias field, spikes).

```python { .api }
# Spatial augmentation
class RandomFlip(SpatialTransform):
    def __init__(self, axes: TypeTuple = (0,), flip_probability: float = 0.5): ...

class RandomAffine(SpatialTransform):
    def __init__(
        self,
        scales: TypeRangeFloat = 0.1,
        degrees: TypeRangeFloat = 10,
        translation: TypeRangeFloat = 0,
        **kwargs
    ): ...

class RandomElasticDeformation(SpatialTransform):
    def __init__(self, num_control_points: TypeTuple = 7, max_displacement: TypeRangeFloat = 7.5): ...

# Medical imaging specific augmentation
class RandomMotion(IntensityTransform):
    def __init__(self, degrees: TypeRangeFloat = 10, translation: TypeRangeFloat = 10): ...

class RandomBiasField(IntensityTransform):
    def __init__(self, coefficients: TypeRangeFloat = 0.5): ...

class RandomGhosting(IntensityTransform):
    def __init__(self, num_ghosts: tuple[int, int] = (4, 10), axes: tuple[int, ...] = (0, 1, 2)): ...
```

[Augmentation Transforms](./augmentation.md)

### Transform Composition

Tools for combining transforms into pipelines: sequential composition, random selection from a group of transforms, and custom lambda transforms.

```python { .api }
class Compose(Transform):
    def __init__(self, transforms: Sequence[Transform]): ...

class OneOf(Transform):
    def __init__(self, transforms: dict[Transform, float]): ...

class Lambda(Transform):
    def __init__(self, function: Callable, types_to_apply: Sequence[str] = None): ...
```

[Transform Composition](./composition.md)

### Medical Image Datasets

Pre-built datasets for medical imaging research, including brain atlases, public datasets and challenges, and MedMNIST 3D datasets that are convenient for testing and development.

```python { .api }
# Brain atlases and templates
class Colin27(Subject):
    """Colin27 brain template"""

class ICBM2009CNonlinearSymmetric(Subject):
    """ICBM 2009c nonlinear symmetric brain template"""

# Public datasets
class IXI(SubjectsDataset):
    """IXI dataset - brain MR images from healthy subjects"""

class RSNAMICCAI(SubjectsDataset):
    """RSNA-MICCAI Brain Tumor Radiogenomic Classification dataset"""

# MedMNIST 3D datasets
class OrganMNIST3D(SubjectsDataset):
    """OrganMNIST3D dataset from MedMNIST (3D organ classification)"""
```

[Medical Image Datasets](./datasets.md)

### Utilities and Constants

Helper functions, type definitions, and constants, including file I/O utilities and type conversion functions.

```python { .api }
# Utility functions
def to_tuple(value: Any, length: int = 1) -> tuple[TypeNumber, ...]: ...
def apply_transform_to_file(input_path: TypePath, transform: Transform, output_path: TypePath): ...
def get_torchio_cache_dir() -> Path: ...

# Type definitions
TypePath = Union[str, Path]
TypeSpatialShape = Union[int, tuple[int, int, int]]
TypeSpacing = Union[float, tuple[float, float, float]]

# Constants
INTENSITY = 'intensity'
LABEL = 'label'
DATA = 'data'
AFFINE = 'affine'
```

[Utilities and Constants](./utilities.md)

### Command Line Interface

CLI tools that expose common TorchIO operations from the command line.

```python { .api }
# Available CLI commands:
# tiohd - Print image information and optionally display
# tiotr/torchio-transform - Apply transforms to images from command line
```

**tiohd** - Print image header information and optionally visualize:
- Options: `--plot/-p` (plot using matplotlib), `--show/-s` (show in external viewer), `--label/-l` (treat as label image), `--load` (load data for memory info)
- Usage: `tiohd input.nii.gz --plot --show`

**tiotr/torchio-transform** - Apply transforms to images:
- Arguments: `input_path`, `transform_name`, `output_path`
- Options: `--kwargs/-k` (transform parameters), `--imclass/-c` (image class), `--seed/-s` (random seed), `--verbose/-v`
- Usage: `tiotr input.nii.gz RandomFlip output.nii.gz --kwargs "axes=(0,1)"`

## Types

```python { .api }
# Basic types
TypePath = Union[str, Path]
TypeNumber = Union[int, float]
TypeData = Union[torch.Tensor, np.ndarray]
TypeDataAffine = tuple[torch.Tensor, np.ndarray]
TypeSlice = Union[int, slice]
TypeKeys = Optional[Sequence[str]]

# Numeric tuple types
TypeDoubletInt = tuple[int, int]
TypeTripletInt = tuple[int, int, int]
TypeQuartetInt = tuple[int, int, int, int]
TypeSextetInt = tuple[int, int, int, int, int, int]

TypeDoubleFloat = tuple[float, float]
TypeTripletFloat = tuple[float, float, float]
TypeQuartetFloat = tuple[float, float, float, float]
TypeSextetFloat = tuple[float, float, float, float, float, float]

# Geometric types
TypeTuple = Union[int, TypeTripletInt]
TypeRangeInt = Union[int, TypeDoubletInt]
TypeSpatialShape = Union[int, TypeTripletInt]
TypeSpacing = Union[float, TypeTripletFloat]
TypeRangeFloat = Union[float, TypeDoubleFloat]

# Transform types
TypeCallable = Callable[[torch.Tensor], torch.Tensor]

# Direction matrix types
TypeDirection2D = TypeQuartetFloat
TypeDirection3D = tuple[float, float, float, float, float, float, float, float, float]
TypeDirection = Union[TypeDirection2D, TypeDirection3D]

# Image types
class Image:
    """Base class for medical images"""

class ScalarImage(Image):
    """Intensity/scalar medical images (MRI, CT, etc.)"""

class LabelMap(Image):
    """Segmentation/label images"""

class Subject(dict):
    """Container for multiple medical images and metadata"""

# Transform hierarchy
class Transform:
    """Base class for all transforms"""

class SpatialTransform(Transform):
    """Base class for spatial transformations"""

class IntensityTransform(Transform):
    """Base class for intensity transformations"""
```