0
# Feature Extraction
1
2
Advanced functionality for extracting features from models, analyzing model architecture, and manipulating pretrained models for custom use cases.
3
4
## Capabilities
5
6
### Feature Extractor Creation
7
8
Create feature extractors that return intermediate representations from selected nodes of a model's FX graph.
9
10
```python { .api }
11
def create_feature_extractor(
12
model: torch.nn.Module,
13
return_nodes: Union[Dict[str, str], List[str]],
14
train_return_nodes: Union[Dict[str, str], List[str]] = None,
15
suppress_diff_warnings: bool = False,
16
tracer_kwargs: Dict[str, Any] = None,
17
**kwargs
18
):
19
"""
20
Create a feature extractor from a model using FX graph tracing.
21
22
Args:
23
model: Source model to extract features from
24
return_nodes: Nodes to return features from. Can be dict mapping
25
node names to output names, or list of node names
26
train_return_nodes: Different nodes for training mode
27
suppress_diff_warnings: Suppress warnings about model differences
28
tracer_kwargs: Additional arguments for FX tracer
29
**kwargs: Additional arguments
30
31
Returns:
32
Feature extractor model that returns specified intermediate features
33
"""
34
35
def get_graph_node_names(
36
model: torch.nn.Module,
37
tracer_kwargs: Dict[str, Any] = None,
38
suppress_diff_warnings: bool = False
39
) -> Tuple[List[str], List[str]]:
40
"""
41
Get node names from model's FX graph for feature extraction.
42
43
Args:
44
model: Model to analyze
45
tracer_kwargs: Additional tracer arguments
46
suppress_diff_warnings: Suppress model difference warnings
47
48
Returns:
49
Tuple of (node_names, node_types) for available extraction points
50
"""
51
```
52
53
## Feature Extraction Classes
54
55
### Hook-Based Feature Extraction
56
57
```python { .api }
58
class FeatureInfo:
59
"""
60
Information about extracted features.
61
62
Args:
63
feature_info: List of feature information dictionaries
64
out_indices: Output indices for features
65
"""
66
67
def __init__(
68
self,
69
feature_info: List[Dict[str, Any]],
70
out_indices: List[int]
71
): ...
72
73
def get_dicts(self, keys: List[str] = None) -> List[Dict[str, Any]]:
74
"""Get feature info as list of dictionaries."""
75
76
def channels(self, idx: int = None) -> Union[List[int], int]:
77
"""Get feature channels."""
78
79
def reduction(self, idx: int = None) -> Union[List[int], int]:
80
"""Get feature reduction factors."""
81
82
class FeatureHooks:
83
"""
84
Feature extraction using forward hooks.
85
86
Args:
87
hooks: List of hook functions
88
named_modules: Dictionary of named modules
89
out_map: Output mapping for feature names
90
"""
91
92
def __init__(
93
self,
94
hooks: List[Callable],
95
named_modules: Dict[str, torch.nn.Module],
96
out_map: List[int] = None
97
): ...
98
99
def get_output(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
100
"""Get hooked features from forward pass."""
101
102
class FeatureHookNet(torch.nn.Module):
103
"""
104
Wrapper that uses hooks to extract features during forward pass.
105
106
Args:
107
model: Base model to wrap
108
out_indices: Indices of layers to extract features from
109
out_map: Mapping of output names
110
return_interm: Return intermediate features
111
**kwargs: Additional arguments
112
"""
113
114
def __init__(
115
self,
116
model: torch.nn.Module,
117
out_indices: List[int],
118
out_map: List[str] = None,
119
return_interm: bool = False,
120
**kwargs
121
): ...
122
123
class FeatureListNet(torch.nn.Module):
124
"""
125
Wrapper that returns features as a list.
126
127
Args:
128
model: Base model to wrap
129
out_indices: Indices of layers to extract features from
130
**kwargs: Additional arguments
131
"""
132
133
def __init__(
134
self,
135
model: torch.nn.Module,
136
out_indices: List[int],
137
**kwargs
138
): ...
139
140
class FeatureDictNet(torch.nn.Module):
141
"""
142
Wrapper that returns features as a dictionary.
143
144
Args:
145
model: Base model to wrap
146
out_indices: Indices of layers to extract features from
147
out_map: Names for output features
148
**kwargs: Additional arguments
149
"""
150
151
def __init__(
152
self,
153
model: torch.nn.Module,
154
out_indices: List[int],
155
out_map: List[str] = None,
156
**kwargs
157
): ...
158
```
159
160
### FX-Based Feature Extraction
161
162
```python { .api }
163
class FeatureGraphNet(torch.nn.Module):
164
"""
165
FX-based feature extraction network.
166
167
Args:
168
model: Base model
169
out_indices: Output layer indices
170
out_map: Feature name mapping
171
**kwargs: Additional arguments
172
"""
173
174
def __init__(
175
self,
176
model: torch.nn.Module,
177
out_indices: List[int],
178
out_map: List[str] = None,
179
**kwargs
180
): ...
181
182
class GraphExtractNet(torch.nn.Module):
183
"""
184
Graph-based feature extraction using FX.
185
186
Args:
187
model: Source model
188
return_nodes: Nodes to extract features from
189
**kwargs: Additional arguments
190
"""
191
192
def __init__(
193
self,
194
model: torch.nn.Module,
195
return_nodes: Dict[str, str],
196
**kwargs
197
): ...
198
```
199
200
## Model Manipulation
201
202
### Model Analysis and Modification
203
204
```python { .api }
205
def model_parameters(
206
model: torch.nn.Module,
207
exclude_head: bool = False,
208
recurse: bool = True
209
) -> Iterator[torch.nn.Parameter]:
210
"""
211
Get model parameters with filtering options.
212
213
Args:
214
model: Model to analyze
215
exclude_head: Exclude classifier/head parameters
216
recurse: Recurse into submodules
217
218
Returns:
219
Iterator over model parameters
220
"""
221
222
def named_apply(
223
fn: Callable,
224
module: torch.nn.Module,
225
name: str = '',
226
depth_first: bool = True,
227
include_root: bool = False
228
) -> torch.nn.Module:
229
"""
230
Apply function to named modules recursively.
231
232
Args:
233
fn: Function to apply to each module
234
module: Root module
235
name: Current module name
236
depth_first: Apply depth-first traversal
237
include_root: Include root module
238
239
Returns:
240
Modified module
241
"""
242
243
def named_modules(
244
module: torch.nn.Module,
245
memo: set = None,
246
prefix: str = '',
247
remove_duplicate: bool = True
248
) -> Iterator[Tuple[str, torch.nn.Module]]:
249
"""
250
Get named modules with filtering.
251
252
Args:
253
module: Root module
254
memo: Set for tracking duplicates
255
prefix: Name prefix
256
remove_duplicate: Remove duplicate modules
257
258
Returns:
259
Iterator of (name, module) pairs
260
"""
261
262
def group_modules(
263
module: torch.nn.Module,
264
group_matcher: Callable,
265
output_values: bool = False,
266
reverse: bool = False
267
) -> Union[Dict[int, List[str]], Dict[int, List[torch.nn.Module]]]:
268
"""
269
Group modules by matching criteria.
270
271
Args:
272
module: Module to group
273
group_matcher: Function to determine group membership
274
output_values: Return module objects instead of names
275
reverse: Reverse the grouping order
276
277
Returns:
278
Dictionary mapping group IDs to module names/objects
279
"""
280
281
def group_parameters(
282
module: torch.nn.Module,
283
group_matcher: Callable,
284
output_values: bool = False,
285
reverse: bool = False
286
) -> Union[Dict[int, List[str]], Dict[int, List[torch.nn.Parameter]]]:
287
"""
288
Group parameters by matching criteria.
289
290
Args:
291
module: Module to analyze
292
group_matcher: Function to determine group membership
293
output_values: Return parameter objects instead of names
294
reverse: Reverse the grouping order
295
296
Returns:
297
Dictionary mapping group IDs to parameter names/objects
298
"""
299
300
def checkpoint_seq(
301
functions: List[Callable],
302
segments: int = 1,
303
input: torch.Tensor = None,
304
**kwargs
305
) -> torch.Tensor:
306
"""
307
Apply gradient checkpointing to sequence of functions.
308
309
Args:
310
functions: List of functions to apply
311
segments: Number of checkpoint segments
312
input: Input tensor
313
**kwargs: Additional arguments
314
315
Returns:
316
Output tensor with gradient checkpointing applied
317
"""
318
```
319
320
### Model Adaptation
321
322
```python { .api }
323
def adapt_input_conv(
324
model: torch.nn.Module,
325
in_chans: int,
326
conv_layer: str = None
327
) -> torch.nn.Module:
328
"""
329
Adapt model's input convolution for different channel counts.
330
331
Args:
332
model: Model to adapt
333
in_chans: New number of input channels
334
conv_layer: Name of convolution layer to adapt
335
336
Returns:
337
Model with adapted input convolution
338
"""
339
340
def load_pretrained(
341
model: torch.nn.Module,
342
cfg: Dict[str, Any] = None,
343
num_classes: int = 1000,
344
in_chans: int = 3,
345
filter_fn: Callable = None,
346
strict: bool = True,
347
progress: bool = False
348
) -> None:
349
"""
350
Load pretrained weights into model.
351
352
Args:
353
model: Model to load weights into
354
cfg: Pretrained configuration
355
num_classes: Number of output classes
356
in_chans: Number of input channels
357
filter_fn: Function to filter state dict keys
358
strict: Strict loading mode
359
progress: Show download progress
360
"""
361
362
def load_custom_pretrained(
363
model: torch.nn.Module,
364
cfg: Dict[str, Any] = None,
365
load_fn: Callable = None,
366
progress: bool = False,
367
check_hash: bool = False
368
) -> None:
369
"""
370
Load custom pretrained weights.
371
372
Args:
373
model: Model to load weights into
374
cfg: Configuration dictionary
375
load_fn: Custom loading function
376
progress: Show progress
377
check_hash: Verify file hash
378
"""
379
380
def build_model_with_cfg(
381
model_cls: Callable,
382
variant: str,
383
pretrained: bool,
384
pretrained_cfg: Dict[str, Any],
385
model_cfg: Dict[str, Any],
386
feature_cfg: Dict[str, Any],
387
**kwargs
388
) -> torch.nn.Module:
389
"""
390
Build model with configuration.
391
392
Args:
393
model_cls: Model class constructor
394
variant: Model variant name
395
pretrained: Load pretrained weights
396
pretrained_cfg: Pretrained configuration
397
model_cfg: Model configuration
398
feature_cfg: Feature extraction configuration
399
**kwargs: Additional model arguments
400
401
Returns:
402
Configured model instance
403
"""
404
```
405
406
## State Dictionary Utilities
407
408
### State Dict Manipulation
409
410
```python { .api }
411
def clean_state_dict(
412
state_dict: Dict[str, Any],
413
model: torch.nn.Module = None
414
) -> Dict[str, Any]:
415
"""
416
Clean state dictionary by removing unwanted keys.
417
418
Args:
419
state_dict: State dictionary to clean
420
model: Model to match against
421
422
Returns:
423
Cleaned state dictionary
424
"""
425
426
def load_state_dict(
427
checkpoint_path: str,
428
use_ema: bool = True,
429
device: torch.device = 'cpu'
430
) -> Dict[str, Any]:
431
"""
432
Load state dictionary from checkpoint file.
433
434
Args:
435
checkpoint_path: Path to checkpoint file
436
use_ema: Use EMA weights if available
437
device: Device to load tensors on
438
439
Returns:
440
Loaded state dictionary
441
"""
442
443
def load_checkpoint(
444
model: torch.nn.Module,
445
checkpoint_path: str,
446
use_ema: bool = False,
447
device: torch.device = 'cpu',
448
strict: bool = True
449
) -> None:
450
"""
451
Load checkpoint into model.
452
453
Args:
454
model: Model to load checkpoint into
455
checkpoint_path: Path to checkpoint file
456
use_ema: Use EMA weights if available
457
device: Device for loading
458
strict: Strict loading mode
459
"""
460
461
def remap_state_dict(
462
state_dict: Dict[str, Any],
463
remap_dict: Dict[str, str]
464
) -> Dict[str, Any]:
465
"""
466
Remap state dictionary keys using mapping rules.
467
468
Args:
469
state_dict: Original state dictionary
470
remap_dict: Mapping from old keys to new keys
471
472
Returns:
473
Remapped state dictionary
474
"""
475
476
def resume_checkpoint(
477
model: torch.nn.Module,
478
checkpoint_path: str,
479
optimizer: torch.optim.Optimizer = None,
480
loss_scaler = None,
481
log_info: bool = True
482
) -> Dict[str, Any]:
483
"""
484
Resume training from checkpoint.
485
486
Args:
487
model: Model to resume
488
checkpoint_path: Path to checkpoint
489
optimizer: Optimizer to resume
490
loss_scaler: Loss scaler to resume
491
log_info: Log resume information
492
493
Returns:
494
Dictionary with resume information
495
"""
496
```
497
498
## Usage Examples
499
500
### Basic Feature Extraction
501
502
```python
503
import timm
504
from timm.models import create_feature_extractor
505
506
# Create a model
507
model = timm.create_model('resnet50', pretrained=True)
508
509
# Create feature extractor for specific layers
510
feature_extractor = create_feature_extractor(
511
model,
512
return_nodes={
513
'layer1': 'feat1',
514
'layer2': 'feat2',
515
'layer3': 'feat3',
516
'layer4': 'feat4'
517
}
518
)
519
520
# Extract features
521
import torch
522
x = torch.randn(1, 3, 224, 224)
523
features = feature_extractor(x)
524
print(f"Feature shapes: {[(k, v.shape) for k, v in features.items()]}")
525
```
526
527
### List-Based Feature Extraction
528
529
```python
530
from timm.models import FeatureListNet
531
532
# Create model that returns features as list
533
model = timm.create_model('efficientnet_b0', pretrained=True, features_only=True)
534
535
# Or wrap existing model
536
base_model = timm.create_model('resnet34', pretrained=True)
537
feature_model = FeatureListNet(base_model, out_indices=[1, 2, 3, 4])
538
539
# Extract features (reuses `timm` and the input tensor `x` from the Basic Feature Extraction example)
540
features = feature_model(x)
541
print(f"Number of feature maps: {len(features)}")
542
for i, feat in enumerate(features):
543
    print(f"Feature {i}: {feat.shape}")
544
```
545
546
### Model Analysis
547
548
```python
549
from timm.models import get_graph_node_names, model_parameters
550
551
# Analyze model structure
552
model = timm.create_model('vit_base_patch16_224', pretrained=True)
553
554
# Get available nodes for feature extraction
555
node_names, node_types = get_graph_node_names(model)
556
print(f"Available nodes: {len(node_names)}")
557
print(f"Sample nodes: {node_names[:10]}")
558
559
# Count parameters
560
total_params = sum(p.numel() for p in model_parameters(model))
561
print(f"Total parameters: {total_params:,}")
562
563
# Count parameters excluding head
564
body_params = sum(p.numel() for p in model_parameters(model, exclude_head=True))
565
print(f"Body parameters: {body_params:,}")
566
```
567
568
### Model Adaptation
569
570
```python
571
from timm.models import adapt_input_conv, load_checkpoint, resume_checkpoint
572
573
# Adapt model for different input channels (e.g., grayscale)
574
model = timm.create_model('resnet50', pretrained=True)
575
model = adapt_input_conv(model, in_chans=1)
576
577
# Load custom checkpoint
578
load_checkpoint(model, 'path/to/checkpoint.pth')
579
580
# Resume training (`optimizer` must be an already-constructed torch.optim.Optimizer)
581
checkpoint_info = resume_checkpoint(
582
model,
583
'path/to/checkpoint.pth',
584
optimizer=optimizer,
585
log_info=True
586
)
587
start_epoch = checkpoint_info['epoch']
588
```
589
590
### Advanced Feature Configuration
591
592
```python
593
# Create model with specific feature configuration
594
model = timm.create_model(
595
'resnet50',
596
pretrained=True,
597
features_only=True,
598
out_indices=[1, 2, 3, 4], # Which stages to output
599
output_stride=16, # Overall output stride
600
global_pool='', # Disable global pooling
601
num_classes=0 # Remove classifier
602
)
603
604
# Get feature info
605
feature_info = model.feature_info.get_dicts()
606
for info in feature_info:
607
    print(f"Layer: {info['module']}, Channels: {info['num_chs']}, Reduction: {info['reduction']}")
608
```
609
610
## Types
611
612
```python { .api }
613
from typing import Optional, Union, List, Dict, Callable, Any, Tuple, Iterator
614
import torch
615
616
# Feature extraction types
617
FeatureDict = Dict[str, torch.Tensor]
618
FeatureList = List[torch.Tensor]
619
NodeSpec = Union[Dict[str, str], List[str]]
620
621
# Model analysis types
622
ParameterIterator = Iterator[torch.nn.Parameter]
623
ModuleDict = Dict[str, torch.nn.Module]
624
ParameterDict = Dict[str, torch.nn.Parameter]
625
626
# State dict types
627
StateDict = Dict[str, Any]
628
RemapDict = Dict[str, str]
629
630
# Hook types
631
HookFunction = Callable[[torch.nn.Module, torch.Tensor, torch.Tensor], None]
632
FilterFunction = Callable[[str, torch.nn.Parameter], bool]
633
```