# Model Creation and Management

This section covers discovering, creating, and configuring computer vision models from TIMM's collection of 1000+ pretrained models across 90+ architectures.

## Capabilities

### Model Creation

Create model instances with configurable pretrained weights, number of output classes, and architectural modifications.

```python { .api }
def create_model(
    model_name: str,
    pretrained: bool = False,
    pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
    pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
    checkpoint_path: Optional[Union[str, Path]] = None,
    cache_dir: Optional[Union[str, Path]] = None,
    scriptable: Optional[bool] = None,
    exportable: Optional[bool] = None,
    no_jit: Optional[bool] = None,
    **kwargs: Any
) -> torch.nn.Module:
    """
    Create a model instance.

    Args:
        model_name: Name of model to instantiate
        pretrained: Load pretrained weights if True
        pretrained_cfg: Pretrained configuration override (dict or cfg name)
        pretrained_cfg_overlay: Dictionary of config overrides
        num_classes: Number of output classes (default: 1000)
        in_chans: Number of input image channels (default: 3)
        global_pool: Global pooling type override
        scriptable: Set layer config so model is jit scriptable
        exportable: Set layer config so model is traceable/ONNX exportable
        no_jit: Disable jit related set/reset of layer config
        checkpoint_path: Path to load checkpoint from instead of pretrained weights
        cache_dir: Cache directory for downloaded pretrained weights
        **kwargs: Model-specific arguments

    Returns:
        Instantiated model
    """
```

#### Usage Examples

```python
import timm

# Basic model creation
model = timm.create_model('resnet50', pretrained=True)

# Custom number of classes for fine-tuning
model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=10)

# Model for feature extraction
feature_model = timm.create_model('vit_base_patch16_224', pretrained=True, features_only=True)

# Model optimized for export
export_model = timm.create_model('resnet18', pretrained=True, scriptable=True, exportable=True)

# Load from custom checkpoint
model = timm.create_model('resnet50', checkpoint_path='/path/to/checkpoint.pth')
```

### Model Discovery

Functions to explore and filter the available model architectures and pretrained weights.

```python { .api }
def list_models(
    filter: str = '',
    module: str = '',
    pretrained: bool = False,
    exclude_filters: str = '',
    name_matches_cfg: bool = False,
    include_tags: Optional[bool] = None
) -> list[str]:
    """
    List available models.

    Args:
        filter: Wildcard filter string to limit model names
        module: Specific module/architecture to limit results
        pretrained: Only models with pretrained weights if True
        exclude_filters: Exclude models matching these patterns
        name_matches_cfg: Only models where name matches config
        include_tags: Include pretrained tags in model names
            (if None, defaults to True when pretrained=True, else False)

    Returns:
        List of model names matching criteria
    """

def list_pretrained(filter: str = '') -> list[str]:
    """
    List models with pretrained weights available.

    Args:
        filter: Wildcard filter for model names

    Returns:
        List of model names with pretrained weights
    """

def list_modules() -> list[str]:
    """
    List available model modules/architectures.

    Returns:
        List of module names
    """
```

#### Usage Examples

```python
# List all models
all_models = timm.list_models()

# Filter models by architecture
resnet_models = timm.list_models('*resnet*')
vit_models = timm.list_models('vit_*')

# Only models with pretrained weights
pretrained_models = timm.list_models(pretrained=True)

# List specific architecture variants
efficientnet_pretrained = timm.list_pretrained('efficientnet*')

# Available model families
architectures = timm.list_modules()
```
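
The wildcard filters used above can be illustrated without installing TIMM. The following is a minimal sketch, assuming `fnmatch`-style matching with exclusions applied after inclusion; `filter_names` and the sample `names` list are hypothetical and not part of TIMM's API.

```python
from fnmatch import fnmatch

def filter_names(names, include='', exclude=()):
    """Hypothetical helper: keep names matching `include`, drop those matching any `exclude` pattern."""
    # An empty include pattern means "match everything", like calling list_models() with no filter.
    kept = [n for n in names if not include or fnmatch(n, include)]
    # Exclusion patterns are applied after inclusion.
    return sorted(n for n in kept if not any(fnmatch(n, pat) for pat in exclude))

# Sample names for illustration only; the real list comes from timm.list_models().
names = ['resnet18', 'resnet50', 'wide_resnet50_2', 'vit_base_patch16_224', 'efficientnet_b0']

resnet_like = filter_names(names, '*resnet*')
vit_only = filter_names(names, 'vit_*')
no_wide = filter_names(names, '*resnet*', exclude=['wide_*'])
```

Note that `'*resnet*'` also matches `wide_resnet50_2`, which is why an exclusion pattern is often useful alongside a broad include pattern.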

### Model Validation

Utilities to validate model names and check availability of pretrained weights.

```python { .api }
def is_model(model_name: str) -> bool:
    """
    Check if model name is valid and available.

    Args:
        model_name: Name to check

    Returns:
        True if model exists, False otherwise
    """

def is_model_pretrained(model_name: str) -> bool:
    """
    Check if model has pretrained weights available.

    Args:
        model_name: Model name to check

    Returns:
        True if pretrained weights exist, False otherwise
    """

def model_entrypoint(model_name: str) -> Callable:
    """
    Get the entrypoint function for a model.

    Args:
        model_name: Name of model

    Returns:
        Model creation function
    """
```
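
A name check is most useful when it can also point to likely alternatives. The following is a minimal sketch of that pattern using only the standard library; `check_model_name` is a hypothetical helper, and the `known` list is sample data standing in for the output of `timm.list_models()`.

```python
from difflib import get_close_matches

def check_model_name(name, known_names):
    """Hypothetical helper: return (is_valid, suggestions) for a requested model name."""
    if name in known_names:
        return True, []
    # Suggest up to three similar names to help catch typos.
    return False, get_close_matches(name, known_names, n=3, cutoff=0.6)

# Sample names for illustration only.
known = ['resnet18', 'resnet50', 'vit_base_patch16_224', 'efficientnet_b0']

ok, suggestions = check_model_name('resnet_50', known)
if not ok:
    print(f"Unknown model; did you mean one of {suggestions}?")
```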

### Model Configuration

Retrieve model configuration and metadata.

```python { .api }
def get_pretrained_cfg(model_name: str) -> PretrainedCfg:
    """
    Get pretrained configuration for model.

    Args:
        model_name: Name of model

    Returns:
        PretrainedCfg object (a dataclass in recent timm releases; call
        .to_dict() for a plain dictionary) with fields including:
        - input_size: Expected input dimensions
        - mean: Normalization mean values
        - std: Normalization standard deviation values
        - num_classes: Number of output classes
        - pool_size: Global pooling output size
        - crop_pct: Center crop percentage
        - interpolation: Resize interpolation method
        - first_conv: Name of first convolutional layer
        - classifier: Name of classifier layer
    """

def get_pretrained_cfg_value(model_name: str, cfg_key: str):
    """
    Get specific configuration value for pretrained model.

    Args:
        model_name: Name of model
        cfg_key: Configuration key to retrieve

    Returns:
        Configuration value for specified key
    """
```

#### Usage Examples

```python
# Get complete model configuration (a PretrainedCfg object, so use attribute access)
cfg = timm.get_pretrained_cfg('resnet50')
print(f"Input size: {cfg.input_size}")
print(f"Mean: {cfg.mean}")
print(f"Std: {cfg.std}")

# Get specific configuration values
input_size = timm.get_pretrained_cfg_value('efficientnet_b0', 'input_size')
crop_pct = timm.get_pretrained_cfg_value('vit_base_patch16_224', 'crop_pct')

# Validate model availability
if timm.is_model('my_custom_model'):
    model = timm.create_model('my_custom_model')

# Check for pretrained weights
if timm.is_model_pretrained('resnet101'):
    model = timm.create_model('resnet101', pretrained=True)
```
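
Configuration values such as `input_size` and `crop_pct` typically drive evaluation preprocessing. The following is a minimal sketch, assuming the common resize-then-center-crop convention in which the image is resized to `crop size / crop_pct` before cropping; `eval_resize_size` is a hypothetical helper, and the values are examples rather than values read from TIMM.

```python
import math

def eval_resize_size(input_size, crop_pct):
    """Hypothetical helper: shorter-side resize target before a center crop."""
    # input_size is (channels, height, width); the crop is taken at height x width.
    _, height, width = input_size
    crop = min(height, width)
    # Resize so that the center crop covers `crop_pct` of the resized image.
    return math.floor(crop / crop_pct)

# Example values only; query the real ones with timm.get_pretrained_cfg_value().
input_size = (3, 224, 224)
crop_pct = 0.875

resize_to = eval_resize_size(input_size, crop_pct)
print(f"Resize shorter side to {resize_to}, then center-crop to {input_size[1]}x{input_size[2]}")
```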

### Advanced Model Creation

Advanced patterns for model customization and creation.

#### Model Factory Functions

```python { .api }
def create_model_from_pretrained(
    model_name: str,
    pretrained_cfg: dict = None,
    **model_kwargs
) -> torch.nn.Module:
    """
    Create model using specific pretrained configuration.

    Args:
        model_name: Name of model to create
        pretrained_cfg: Custom pretrained configuration
        **model_kwargs: Additional model arguments

    Returns:
        Configured model instance
    """
```

#### Custom Model Registration

```python { .api }
def register_model(fn: Callable = None, *, name: str = None) -> Callable:
    """
    Register a new model architecture.

    Args:
        fn: Model creation function
        name: Optional model name override

    Returns:
        Decorated function
    """
```

#### Usage Examples

```python
import timm
from timm.models import register_model

# Register custom model (MyCustomResNet stands for your own torch.nn.Module subclass)
@register_model
def my_custom_resnet(pretrained=False, **kwargs):
    # Custom ResNet implementation
    model = MyCustomResNet(**kwargs)
    if pretrained:
        # Load custom pretrained weights
        pass
    return model

# Use registered model
custom_model = timm.create_model('my_custom_resnet', pretrained=True)
```
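
The registration pattern itself is a name-to-function lookup table. The following is a toy sketch of that idea, not TIMM's implementation; `register`, `is_registered`, `create`, and `tiny_model` are hypothetical names, and a plain dict stands in for a model so the example has no dependencies.

```python
from typing import Callable, Dict

# Toy registry: maps a model name to the function that builds it.
_ENTRYPOINTS: Dict[str, Callable] = {}

def register(fn: Callable) -> Callable:
    """Record `fn` under its own function name and return it unchanged."""
    _ENTRYPOINTS[fn.__name__] = fn
    return fn

def is_registered(name: str) -> bool:
    """Report whether an entrypoint with this name has been registered."""
    return name in _ENTRYPOINTS

def create(name: str, **kwargs):
    """Look up the entrypoint by name and call it with the given arguments."""
    if not is_registered(name):
        raise KeyError(f"Unknown model: {name}")
    return _ENTRYPOINTS[name](**kwargs)

@register
def tiny_model(pretrained: bool = False, width: int = 8):
    # A plain dict stands in for a torch.nn.Module to keep the sketch dependency-free.
    return {'name': 'tiny_model', 'pretrained': pretrained, 'width': width}

model = create('tiny_model', width=16)
```

Because the decorator returns the function unchanged, the entrypoint can still be called directly as well as through the name-based lookup.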

### Hugging Face Hub Integration

TIMM integrates with the Hugging Face Hub for loading models and configurations.

```python { .api }
def load_model_config_from_hf(model_id: str) -> dict:
    """
    Load model configuration from Hugging Face Hub.

    Args:
        model_id: Hugging Face model identifier

    Returns:
        Model configuration dictionary
    """

def load_state_dict_from_hf(model_id: str) -> dict:
    """
    Load model weights from Hugging Face Hub.

    Args:
        model_id: Hugging Face model identifier

    Returns:
        Model state dictionary
    """
```

#### Hub Model Loading Examples

```python
# Load a timm-format model from Hugging Face Hub using the hf-hub: prefix
model = timm.create_model('hf-hub:timm/resnet50.a1_in1k', pretrained=True)

# Load local model using local-dir: prefix
model = timm.create_model('local-dir:/path/to/model/folder', pretrained=True)

# Load specific model revision/branch
model = timm.create_model('hf-hub:timm/resnet50.a1_in1k@main', pretrained=True)
```
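
The prefix conventions shown above (`hf-hub:`, `local-dir:`, and an optional `@revision` suffix) can be sketched as a small parser. This is an illustration under stated assumptions, not TIMM's actual parsing code; `split_model_source` and the `org/model-name` identifier are hypothetical.

```python
from typing import Optional, Tuple

def split_model_source(name: str) -> Tuple[str, str, Optional[str]]:
    """Hypothetical parser: return (source, identifier, revision) for a model name."""
    for prefix, source in (('hf-hub:', 'hf-hub'), ('local-dir:', 'local-dir')):
        if name.startswith(prefix):
            identifier = name[len(prefix):]
            break
    else:
        # No known prefix: treat the name as a built-in architecture name.
        return 'builtin', name, None
    revision = None
    # In this sketch only Hub identifiers carry an optional '@revision' suffix.
    if source == 'hf-hub' and '@' in identifier:
        identifier, revision = identifier.rsplit('@', 1)
    return source, identifier, revision

print(split_model_source('resnet50'))
print(split_model_source('hf-hub:org/model-name@main'))
print(split_model_source('local-dir:/path/to/model/folder'))
```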

## Model Architecture Categories

TIMM includes models from the following major categories:

### Vision Transformers
- **ViT**: Vision Transformer variants (Base, Large, Huge)
- **DeiT**: Data-efficient Image Transformers
- **BEiT**: Bidirectional Encoder representation from Image Transformers
- **Swin**: Swin Transformer hierarchical models
- **CaiT**: Class-Attention in Image Transformers
- **CrossViT**: Cross-Attention Multi-Scale Vision Transformer

### Convolutional Networks
- **ResNet**: ResNet and ResNeXt families
- **EfficientNet**: EfficientNet B0-B8 and V2 variants
- **ConvNeXt**: Modern ConvNet architectures
- **RegNet**: Designing Network Design Spaces
- **DenseNet**: Densely Connected Convolutional Networks
- **MobileNet**: MobileNetV3 and variants

### Hybrid Architectures
- **ConViT**: Convolutions meet Vision Transformers
- **LeViT**: Vision Transformer in ConvNet's Clothing
- **CoAtNet**: Convolution and Attention networks
- **MaxViT**: Multi-Axis Vision Transformer

### Specialized Models
- **CLIP**: Vision encoders from CLIP models
- **BEiT3**: Multimodal foundation models
- **EVA**: Enhanced Vision Transformer
- **InternViT**: Large-scale vision foundation models

### Advanced Features

#### NaFlexViT (Native Flexible Vision Transformers)
TIMM supports training and inference at variable aspect ratios and resolutions through NaFlexViT integration.

```python
# Enable NaFlexViT for supported models
model = timm.create_model('vit_base_patch16_224', pretrained=True, use_naflex=True)

# Models with RoPE support can be loaded in NaFlexViT mode
model = timm.create_model('eva_large_patch14_196', pretrained=True, use_naflex=True)
```

#### Forward Intermediates API
Extract intermediate features from a model during the forward pass.

```python
import torch

# Extract intermediate features alongside the final output
model = timm.create_model('resnet50', pretrained=True)
x = torch.randn(1, 3, 224, 224)  # dummy batch: one 224x224 RGB image
final_features, intermediates = model.forward_intermediates(x, indices=[1, 2, 3, 4])
```
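
The idea behind selecting intermediates by index can be shown with a dependency-free toy. This is a sketch of the concept only, not TIMM's implementation: `forward_with_intermediates` is a hypothetical function, and simple arithmetic lambdas stand in for network stages operating on tensors.

```python
from typing import Callable, List, Sequence, Tuple

def forward_with_intermediates(
    x: float,
    stages: Sequence[Callable[[float], float]],
    indices: Sequence[int],
) -> Tuple[float, List[float]]:
    """Run `x` through each stage in order, saving outputs of the stages listed in `indices`."""
    wanted = set(indices)
    intermediates = []
    for i, stage in enumerate(stages):
        x = stage(x)
        if i in wanted:
            intermediates.append(x)
    return x, intermediates

# Five toy "stages"; in a real network these would be blocks of layers acting on tensors.
stages = [lambda v: v + 1, lambda v: v * 2, lambda v: v - 3, lambda v: v * v, lambda v: v / 2]

final, feats = forward_with_intermediates(1.0, stages, indices=[1, 2, 3, 4])
```

Returning the final output together with the collected intermediates avoids running the model twice when both are needed.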

## Types

```python { .api }
from pathlib import Path
from typing import Optional, Union, List, Dict, Callable, Any
import torch

# Model configuration types
# (simplified aliases; recent timm releases define PretrainedCfg as a dataclass)
PretrainedCfg = Dict[str, Any]
ModelCfg = Dict[str, Any]

# Model creation function signature
ModelEntrypoint = Callable[..., torch.nn.Module]

# Filter types for model listing
ModelFilter = Union[str, List[str]]
```