or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

data.mdfeatures.mdindex.mdlayers.mdmodels.mdtraining.mdutils.md

features.mddocs/

# Feature Extraction

Advanced functionality for extracting features from models, analyzing model architecture, and manipulating pretrained models for custom use cases.

## Capabilities

### Feature Extractor Creation

Create feature extractors that can extract intermediate representations from any layer of a model.

```python { .api }

def create_feature_extractor(
    model: torch.nn.Module,
    return_nodes: Union[Dict[str, str], List[str]],
    train_return_nodes: Optional[Union[Dict[str, str], List[str]]] = None,
    suppress_diff_warnings: bool = False,
    tracer_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> torch.nn.Module:
    """
    Create a feature extractor from a model using FX graph tracing.

    Args:
        model: Source model to extract features from.
        return_nodes: Graph nodes whose outputs are returned. Either a dict
            mapping node names to output names, or a list of node names.
        train_return_nodes: Optional alternative nodes to return while the
            model is in training mode; same format as ``return_nodes``.
        suppress_diff_warnings: Suppress warnings about differences between
            the train-mode and eval-mode traced graphs.
        tracer_kwargs: Optional extra keyword arguments for the FX tracer.
        **kwargs: Additional arguments.

    Returns:
        A module whose forward pass returns a dict of the requested
        intermediate features, keyed by output name.
    """

def get_graph_node_names(
    model: torch.nn.Module,
    tracer_kwargs: Optional[Dict[str, Any]] = None,
    suppress_diff_warnings: bool = False,
) -> Tuple[List[str], List[str]]:
    """
    Get node names from a model's FX graph for use as ``return_nodes``.

    Args:
        model: Model to trace.
        tracer_kwargs: Optional extra keyword arguments for the FX tracer.
        suppress_diff_warnings: Suppress warnings about differences between
            the train-mode and eval-mode traced graphs.

    Returns:
        Tuple ``(train_nodes, eval_nodes)``: the node names found when
        tracing the model in training mode and in eval mode, respectively.
        These are the available extraction points for feature extraction.
    """

```

## Feature Extraction Classes

### Hook-Based Feature Extraction

```python { .api }

class FeatureInfo:
    """
    Metadata describing the feature maps returned by a feature extractor.

    Args:
        feature_info: One dictionary per feature level (for example with
            ``num_chs``, ``reduction`` and ``module`` entries).
        out_indices: Indices of the feature levels that are output.
    """

    def __init__(
        self,
        feature_info: List[Dict[str, Any]],
        out_indices: List[int],
    ): ...

    def get_dicts(self, keys: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """
        Get the feature info as a list of dictionaries.

        Args:
            keys: Optional list of keys to select from each dictionary.
        """

    def channels(self, idx: Optional[int] = None) -> Union[List[int], int]:
        """
        Get the number of channels of the output features.

        Args:
            idx: Index of a single feature level, or ``None`` for all levels.

        Returns:
            A single channel count when ``idx`` is given, otherwise a list.
        """

    def reduction(self, idx: Optional[int] = None) -> Union[List[int], int]:
        """
        Get the spatial reduction factors of the output features.

        Args:
            idx: Index of a single feature level, or ``None`` for all levels.

        Returns:
            A single reduction factor when ``idx`` is given, otherwise a list.
        """

class FeatureHooks:
    """
    Feature extraction using forward hooks.

    Args:
        hooks: Hook functions to register.
        named_modules: Mapping of module name to module for the hooked model.
        out_map: Optional output names (or indices) for the hooked features.
    """

    def __init__(
        self,
        hooks: List[Callable],
        named_modules: Dict[str, torch.nn.Module],
        out_map: Optional[List[Union[int, str]]] = None,
    ): ...

    def get_output(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return the features captured by the hooks during a forward pass."""

class FeatureHookNet(torch.nn.Module):
    """
    Wrapper that uses forward hooks to extract features during the forward pass.

    Args:
        model: Base model to wrap.
        out_indices: Indices of the layers to extract features from.
        out_map: Optional names for the output features.
        return_interm: Return intermediate features.
        **kwargs: Additional arguments.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        out_indices: List[int],
        out_map: Optional[List[str]] = None,
        return_interm: bool = False,
        **kwargs,
    ): ...

class FeatureListNet(torch.nn.Module):
    """
    Feature-extraction wrapper whose forward pass yields a list of tensors.

    Args:
        model: The model being wrapped.
        out_indices: Which layers (by index) to return features from.
        **kwargs: Extra keyword arguments.
    """

    def __init__(self, model: torch.nn.Module, out_indices: List[int], **kwargs): ...

class FeatureDictNet(torch.nn.Module):
    """
    Wrapper that returns features as a dictionary.

    Args:
        model: Base model to wrap.
        out_indices: Indices of the layers to extract features from.
        out_map: Optional names for the output features (dictionary keys).
        **kwargs: Additional arguments.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        out_indices: List[int],
        out_map: Optional[List[str]] = None,
        **kwargs,
    ): ...

```

### FX-Based Feature Extraction

```python { .api }

class FeatureGraphNet(torch.nn.Module):
    """
    FX-based feature extraction network.

    Args:
        model: Base model.
        out_indices: Output layer indices.
        out_map: Optional names for the output features.
        **kwargs: Additional arguments.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        out_indices: List[int],
        out_map: Optional[List[str]] = None,
        **kwargs,
    ): ...

class GraphExtractNet(torch.nn.Module):
    """
    Graph-based feature extraction using FX.

    Args:
        model: Source model.
        return_nodes: Nodes to extract features from. Either a dict mapping
            node names to output names or a list of node names, the same
            forms accepted by ``create_feature_extractor``.
        **kwargs: Additional arguments.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        return_nodes: Union[Dict[str, str], List[str]],
        **kwargs,
    ): ...

```

## Model Manipulation

### Model Analysis and Modification

```python { .api }

def model_parameters(model: torch.nn.Module, exclude_head: bool = False,
                     recurse: bool = True) -> Iterator[torch.nn.Parameter]:
    """
    Iterate over a model's parameters, optionally filtered.

    Args:
        model: Model whose parameters are yielded.
        exclude_head: If True, leave out the classifier/head parameters.
        recurse: If True, descend into submodules as well.

    Returns:
        An iterator of the selected ``torch.nn.Parameter`` objects.
    """

def named_apply(fn: Callable, module: torch.nn.Module, name: str = '',
                depth_first: bool = True, include_root: bool = False) -> torch.nn.Module:
    """
    Recursively call ``fn`` on a module's named submodules.

    Args:
        fn: Callable invoked for each visited module.
        module: Module where the traversal starts.
        name: Name of ``module`` within the traversal.
        depth_first: Traverse submodules depth-first.
        include_root: Also call ``fn`` on ``module`` itself.

    Returns:
        The (possibly modified) module.
    """

def named_modules(
    module: torch.nn.Module,
    memo: Optional[set] = None,
    prefix: str = '',
    remove_duplicate: bool = True,
) -> Iterator[Tuple[str, torch.nn.Module]]:
    """
    Get named modules with filtering.

    Args:
        module: Root module.
        memo: Optional set used to track already-visited (duplicate) modules.
        prefix: Prefix prepended to module names.
        remove_duplicate: Skip modules that have already been yielded.

    Returns:
        Iterator of ``(name, module)`` pairs.
    """

def group_modules(module: torch.nn.Module, group_matcher: Callable,
                  output_values: bool = False, reverse: bool = False
                  ) -> Union[Dict[int, List[str]], Dict[int, List[torch.nn.Module]]]:
    """
    Partition a module's submodules into groups using a matcher.

    Args:
        module: Module whose submodules are grouped.
        group_matcher: Callable deciding which group each submodule belongs to.
        output_values: If True, return module objects rather than their names.
        reverse: If True, reverse the order of the groups.

    Returns:
        Dict from group id to the list of module names (or module objects).
    """

def group_parameters(module: torch.nn.Module, group_matcher: Callable,
                     output_values: bool = False, reverse: bool = False
                     ) -> Union[Dict[int, List[str]], Dict[int, List[torch.nn.Parameter]]]:
    """
    Partition a module's parameters into groups using a matcher.

    Args:
        module: Module whose parameters are grouped.
        group_matcher: Callable deciding which group each parameter belongs to.
        output_values: If True, return parameter objects rather than names.
        reverse: If True, reverse the order of the groups.

    Returns:
        Dict from group id to the list of parameter names (or parameter objects).
    """

def checkpoint_seq(
    functions: List[Callable],
    segments: int = 1,
    input: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """
    Apply gradient checkpointing to a sequence of functions.

    Args:
        functions: Functions applied in order.
        segments: Number of checkpoint segments.
        input: Input tensor. The name shadows the ``input`` builtin but is
            kept so existing keyword callers keep working.
        **kwargs: Additional arguments.

    Returns:
        Output tensor, computed with gradient checkpointing applied.
    """

```

### Model Adaptation

```python { .api }

def adapt_input_conv(
    model: torch.nn.Module,
    in_chans: int,
    conv_layer: Optional[str] = None,
) -> torch.nn.Module:
    """
    Adapt a model's input convolution to a different number of input channels.

    Args:
        model: Model to adapt.
        in_chans: New number of input channels (e.g. 1 for grayscale).
        conv_layer: Optional name of the convolution layer to adapt; when
            ``None`` the model's default input convolution is presumably used.

    Returns:
        Model with the adapted input convolution.
    """

def load_pretrained(
    model: torch.nn.Module,
    cfg: Optional[Dict[str, Any]] = None,
    num_classes: int = 1000,
    in_chans: int = 3,
    filter_fn: Optional[Callable] = None,
    strict: bool = True,
    progress: bool = False,
) -> None:
    """
    Load pretrained weights into a model (in place; nothing is returned).

    Args:
        model: Model to load weights into.
        cfg: Optional pretrained configuration.
        num_classes: Number of output classes.
        in_chans: Number of input channels.
        filter_fn: Optional function used to filter the state dict keys.
        strict: Strict loading mode.
        progress: Show download progress.
    """

def load_custom_pretrained(
    model: torch.nn.Module,
    cfg: Optional[Dict[str, Any]] = None,
    load_fn: Optional[Callable] = None,
    progress: bool = False,
    check_hash: bool = False,
) -> None:
    """
    Load custom pretrained weights into a model (in place; nothing is returned).

    Args:
        model: Model to load weights into.
        cfg: Optional configuration dictionary.
        load_fn: Optional custom loading function.
        progress: Show download progress.
        check_hash: Verify the downloaded file's hash.
    """

def build_model_with_cfg(model_cls: Callable, variant: str, pretrained: bool,
                         pretrained_cfg: Dict[str, Any], model_cfg: Dict[str, Any],
                         feature_cfg: Dict[str, Any], **kwargs) -> torch.nn.Module:
    """
    Instantiate a model variant from its configuration dictionaries.

    Args:
        model_cls: Constructor used to create the model.
        variant: Name of the model variant.
        pretrained: Whether pretrained weights should be loaded.
        pretrained_cfg: Configuration describing the pretrained weights.
        model_cfg: Configuration for the model architecture.
        feature_cfg: Configuration for feature extraction.
        **kwargs: Extra keyword arguments for the model.

    Returns:
        The constructed, configured model.
    """

```

## State Dictionary Utilities

### State Dict Manipulation

```python { .api }

def clean_state_dict(
    state_dict: Dict[str, Any],
    model: Optional[torch.nn.Module] = None,
) -> Dict[str, Any]:
    """
    Clean a state dictionary by removing unwanted keys.

    Args:
        state_dict: State dictionary to clean.
        model: Optional model to match the keys against.

    Returns:
        Cleaned state dictionary.
    """

def load_state_dict(
    checkpoint_path: str,
    use_ema: bool = True,
    device: Union[str, torch.device] = 'cpu',
) -> Dict[str, Any]:
    """
    Load a state dictionary from a checkpoint file.

    Args:
        checkpoint_path: Path to the checkpoint file.
        use_ema: Use EMA weights if available.
        device: Device, or device string such as ``'cpu'``, to load the
            tensors onto.

    Returns:
        Loaded state dictionary.
    """

def load_checkpoint(
    model: torch.nn.Module,
    checkpoint_path: str,
    use_ema: bool = False,
    device: Union[str, torch.device] = 'cpu',
    strict: bool = True,
) -> None:
    """
    Load a checkpoint into a model (in place; nothing is returned).

    Args:
        model: Model to load the checkpoint into.
        checkpoint_path: Path to the checkpoint file.
        use_ema: Use EMA weights if available.
        device: Device, or device string such as ``'cpu'``, used for loading.
        strict: Strict loading mode.
    """

def remap_state_dict(state_dict: Dict[str, Any], remap_dict: Dict[str, str]) -> Dict[str, Any]:
    """
    Rename state-dict keys according to a mapping.

    Args:
        state_dict: The state dictionary to rename keys in.
        remap_dict: Mapping of existing key -> replacement key.

    Returns:
        A state dictionary with the renamed keys.
    """

def resume_checkpoint(
    model: torch.nn.Module,
    checkpoint_path: str,
    optimizer: Optional[torch.optim.Optimizer] = None,
    loss_scaler: Optional[Any] = None,
    log_info: bool = True,
) -> Dict[str, Any]:
    """
    Resume training from a checkpoint.

    Args:
        model: Model to resume.
        checkpoint_path: Path to the checkpoint file.
        optimizer: Optional optimizer whose state is resumed.
        loss_scaler: Optional loss scaler whose state is resumed.
        log_info: Log resume information.

    Returns:
        Dictionary with resume information.

    NOTE(review): upstream timm's ``resume_checkpoint`` appears to return the
    epoch to resume from (an int or None) rather than a dict. Confirm against
    the installed timm version before relying on this return type.
    """

```

## Usage Examples

### Basic Feature Extraction

```python

import torch
import timm
from timm.models import create_feature_extractor

# Create a model
model = timm.create_model('resnet50', pretrained=True)

# Create a feature extractor returning the outputs of layer1-layer4,
# renamed to feat1-feat4
feature_extractor = create_feature_extractor(
    model,
    return_nodes={
        'layer1': 'feat1',
        'layer2': 'feat2',
        'layer3': 'feat3',
        'layer4': 'feat4',
    },
)

# Extract features from a dummy batch of one 3x224x224 input
x = torch.randn(1, 3, 224, 224)
features = feature_extractor(x)
print(f"Feature shapes: {[(k, v.shape) for k, v in features.items()]}")

```

### Hook-Based Feature Extraction

```python

import torch
import timm
from timm.models import FeatureListNet

# Dummy input batch so this example runs on its own
x = torch.randn(1, 3, 224, 224)

# Create model that returns features as list
model = timm.create_model('efficientnet_b0', pretrained=True, features_only=True)

# Or wrap existing model
base_model = timm.create_model('resnet34', pretrained=True)
feature_model = FeatureListNet(base_model, out_indices=[1, 2, 3, 4])

# Extract features
features = feature_model(x)
print(f"Number of feature maps: {len(features)}")
for i, feat in enumerate(features):
    print(f"Feature {i}: {feat.shape}")

```

### Model Analysis

```python

import timm
from timm.models import get_graph_node_names, model_parameters

# Analyze model structure
model = timm.create_model('vit_base_patch16_224', pretrained=True)

# Get available nodes for feature extraction; the first element of the
# returned tuple is a list of node names usable as `return_nodes`
node_names, _ = get_graph_node_names(model)
print(f"Available nodes: {len(node_names)}")
print(f"Sample nodes: {node_names[:10]}")

# Count parameters
total_params = sum(p.numel() for p in model_parameters(model))
print(f"Total parameters: {total_params:,}")

# Count parameters excluding head
body_params = sum(p.numel() for p in model_parameters(model, exclude_head=True))
print(f"Body parameters: {body_params:,}")

```

### Model Adaptation

```python

import torch
import timm
from timm.models import adapt_input_conv, load_checkpoint, resume_checkpoint

# Adapt model for different input channels (e.g., grayscale)
model = timm.create_model('resnet50', pretrained=True)
model = adapt_input_conv(model, in_chans=1)

# Load custom checkpoint
load_checkpoint(model, 'path/to/checkpoint.pth')

# Resume training: the optimizer must exist before its state can be restored
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
checkpoint_info = resume_checkpoint(
    model,
    'path/to/checkpoint.pth',
    optimizer=optimizer,
    log_info=True,
)
start_epoch = checkpoint_info['epoch']

```

### Advanced Feature Configuration

```python

import timm

# Create model with specific feature configuration
model = timm.create_model(
    'resnet50',
    pretrained=True,
    features_only=True,
    out_indices=[1, 2, 3, 4],  # Which stages to output
    output_stride=16,  # Overall output stride
    global_pool='',  # Disable global pooling
    num_classes=0,  # Remove classifier
)

# Get feature info
feature_info = model.feature_info.get_dicts()
for info in feature_info:
    print(f"Layer: {info['module']}, Channels: {info['num_chs']}, Reduction: {info['reduction']}")

```

## Types

```python { .api }

from typing import Optional, Union, List, Dict, Callable, Any, Tuple, Iterator
import torch

# Feature extraction types
FeatureDict = Dict[str, torch.Tensor]
FeatureList = List[torch.Tensor]
NodeSpec = Union[Dict[str, str], List[str]]

# Model analysis types
ParameterIterator = Iterator[torch.nn.Parameter]
ModuleDict = Dict[str, torch.nn.Module]
ParameterDict = Dict[str, torch.nn.Parameter]

# State dict types
StateDict = Dict[str, Any]
RemapDict = Dict[str, str]

# Hook types
# A forward hook is called as hook(module, inputs, output), where `inputs`
# is the tuple of positional arguments passed to forward (not one tensor).
HookFunction = Callable[[torch.nn.Module, Tuple[Any, ...], Any], None]
FilterFunction = Callable[[str, torch.nn.Parameter], bool]

```