# TIMM (PyTorch Image Models)

A collection of image models, layers, utilities, optimizers, schedulers, data loaders, augmentations, and reference training/validation scripts. It provides state-of-the-art computer vision models with reproducible ImageNet training results.

## Package Information

- **Package Name**: timm
- **Language**: Python
- **Installation**: `pip install timm`

## Core Imports

```python
import timm
```

Common patterns for model creation:

```python
from timm import create_model, list_models
```

For working with specific components:

```python
from timm.data import create_loader, create_transform
from timm.optim import create_optimizer_v2
from timm.scheduler import create_scheduler_v2
from timm.loss import LabelSmoothingCrossEntropy
```

## Basic Usage

```python
import timm
import torch

# Create a pretrained model (downloads weights on first use)
model = timm.create_model('resnet50', pretrained=True, num_classes=1000)

# List available models
available_models = timm.list_models('*resnet*')  # All models with "resnet" in the name
pretrained_models = timm.list_pretrained('efficientnet*')  # EfficientNet models with pretrained weights

# Create model for feature extraction
feature_model = timm.create_model('resnet50', pretrained=True, features_only=True)

# Inference on an image
model.eval()
with torch.no_grad():
    # Input tensor should be [batch_size, 3, height, width]
    input_tensor = torch.randn(1, 3, 224, 224)
    predictions = model(input_tensor)

# Get model configuration (PretrainedCfg is a dataclass, so use attribute access)
cfg = timm.get_pretrained_cfg('resnet50')
print(f"Model input size: {cfg.input_size}")
print(f"Model mean: {cfg.mean}")
print(f"Model std: {cfg.std}")
```

## Architecture

TIMM is organized around several key components that work together to provide a complete computer vision ecosystem:

- **Models**: 1000+ pretrained models across 90+ architectures including Vision Transformers, ConvNets, and hybrid approaches
- **Layers**: Comprehensive collection of neural network building blocks optimized for vision tasks
- **Data**: Complete data loading, preprocessing, and augmentation pipeline
- **Training Infrastructure**: Optimizers, schedulers, loss functions, and utilities for model training
- **Scripts**: Production-ready training, validation, and inference scripts

The library's modular design allows users to mix and match components, from using pretrained models for inference to building custom training pipelines with TIMM's optimizers and data loaders.

## Capabilities

### Model Creation and Management

Core functionality for discovering, creating, and configuring computer vision models from TIMM's extensive collection of architectures.

```python { .api }
def create_model(
    model_name: str,
    pretrained: bool = False,
    pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
    pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
    checkpoint_path: Optional[Union[str, Path]] = None,
    cache_dir: Optional[Union[str, Path]] = None,
    scriptable: Optional[bool] = None,
    exportable: Optional[bool] = None,
    no_jit: Optional[bool] = None,
    **kwargs: Any
) -> torch.nn.Module: ...

def list_models(
    filter: Union[str, List[str]] = '',
    module: Union[str, List[str]] = '',
    pretrained: bool = False,
    exclude_filters: Union[str, List[str]] = '',
    name_matches_cfg: bool = False,
    include_tags: Optional[bool] = None
) -> List[str]: ...

def list_pretrained(
    filter: Union[str, List[str]] = '',
    exclude_filters: str = ''
) -> List[str]: ...

def is_model(model_name: str) -> bool: ...

def list_modules() -> List[str]: ...

def model_entrypoint(
    model_name: str,
    module_filter: Optional[str] = None
) -> Callable[..., Any]: ...

def is_model_pretrained(model_name: str) -> bool: ...

def get_pretrained_cfg(
    model_name: str,
    allow_unregistered: bool = True
) -> Optional[PretrainedCfg]: ...

def get_pretrained_cfg_value(
    model_name: str,
    cfg_key: str
) -> Optional[Any]: ...
```

[Model Creation and Management](./models.md)

### Data Processing and Loading

Complete data pipeline including datasets, transforms, augmentation strategies, and high-performance data loaders optimized for computer vision training and inference.

```python { .api }
def create_loader(
    dataset,
    input_size: Union[int, tuple],
    batch_size: int,
    is_training: bool = False,
    use_prefetcher: bool = True,  # Prefetcher is on by default in create_loader
    no_aug: bool = False,
    **kwargs
) -> torch.utils.data.DataLoader: ...  # Wrapped in a prefetching loader when use_prefetcher=True

def create_transform(
    input_size: Union[int, tuple],
    is_training: bool = False,
    use_prefetcher: bool = False,
    no_aug: bool = False,
    scale: Optional[tuple] = None,  # None falls back to (0.08, 1.0) for training
    ratio: Optional[tuple] = None,  # None falls back to (3./4., 4./3.) for training
    **kwargs
): ...

def create_dataset(
    name: str,
    root: str,
    split: str = 'validation',
    is_training: bool = False,
    **kwargs
): ...
```

[Data Processing and Loading](./data.md)

### Neural Network Layers and Components

Extensive collection of neural network building blocks including activations, attention mechanisms, convolutions, normalization layers, and specialized components for vision architectures.

```python { .api }
# Layer creation utilities
def create_conv2d(
    in_channels: int,
    out_channels: int,
    kernel_size: Union[int, List[int]],
    **kwargs
) -> torch.nn.Module: ...

def create_norm_layer(
    layer_name: str,
    num_features: int,
    **kwargs
) -> torch.nn.Module: ...

def create_act_layer(
    name: Optional[str],
    inplace: Optional[bool] = None,
    **kwargs
) -> Optional[torch.nn.Module]: ...

# Configuration functions
def is_scriptable() -> bool: ...
def is_exportable() -> bool: ...
def set_scriptable(mode: bool) -> object: ...  # Context manager
def set_exportable(mode: bool) -> object: ...  # Context manager
```

[Layers and Components](./layers.md)

### Training Infrastructure

Comprehensive training utilities including optimizers, learning rate schedulers, loss functions, and training helpers for building complete training pipelines.

```python { .api }
def create_optimizer_v2(
    model_or_params,
    opt: str = 'sgd',
    lr: Optional[float] = None,  # None uses the optimizer's own default
    weight_decay: float = 0.0,
    momentum: float = 0.9,
    **kwargs
): ...

def create_scheduler_v2(
    optimizer,
    sched: str = 'cosine',
    num_epochs: int = 300,
    **kwargs
): ...  # Returns a (scheduler, num_epochs) tuple

# Loss functions
class LabelSmoothingCrossEntropy(torch.nn.Module): ...
class SoftTargetCrossEntropy(torch.nn.Module): ...
```

[Training Infrastructure](./training.md)

### Model Analysis and Feature Extraction

Advanced functionality for extracting features from models, analyzing model architecture, and manipulating pretrained models for custom use cases.

```python { .api }
def create_feature_extractor(
    model: torch.nn.Module,
    return_nodes: Union[Dict[str, str], List[str]],
    **kwargs
): ...

class FeatureHookNet(torch.nn.Module): ...
class FeatureDictNet(torch.nn.Module): ...

# Weight manipulation: adapts a first-conv weight tensor to a different number of input channels
def adapt_input_conv(
    in_chans: int,
    conv_weight: torch.Tensor
) -> torch.Tensor: ...
```

[Feature Extraction](./features.md)

### Utilities and Helpers

General utilities for distributed training, model management, checkpointing, logging, and other supporting functionality for production computer vision workflows.

```python { .api }
# Model utilities
def unwrap_model(model: torch.nn.Module) -> torch.nn.Module: ...
def freeze(model: torch.nn.Module) -> None: ...
def unfreeze(model: torch.nn.Module) -> None: ...

# Training utilities
class ModelEma: ...
class CheckpointSaver: ...
class AverageMeter: ...

# Distributed training
def init_distributed_device(args) -> torch.device: ...  # Also sets distributed fields on args
def reduce_tensor(tensor: torch.Tensor, n: int) -> torch.Tensor: ...  # n is the world size
```

[Utilities and Helpers](./utils.md)

## Types

```python { .api }
from typing import Optional, Union, List, Dict, Tuple, Callable, Any
from pathlib import Path
import torch

# Common type aliases used throughout TIMM
ModelType = torch.nn.Module
OptimizerType = torch.optim.Optimizer
SchedulerType = torch.optim.lr_scheduler._LRScheduler
TransformType = Callable[[Any], torch.Tensor]
DatasetType = torch.utils.data.Dataset
LoaderType = torch.utils.data.DataLoader

# Configuration types
ConfigDict = Dict[str, Any]
ModelCfg = Dict[str, Any]
# PretrainedCfg is a dataclass in recent timm releases (fields such as
# input_size, mean, std); it is not a plain dict
from timm.models import PretrainedCfg
```