# Model Creation and Management

Comprehensive functionality for discovering, creating, and configuring computer vision models from TIMM's extensive collection of 1000+ pretrained models across 90+ architectures.

## Capabilities

### Model Creation

Create model instances with extensive configuration options, including pretrained weights, custom number of classes, and architectural modifications.

```python { .api }
def create_model(
    model_name: str,
    pretrained: bool = False,
    pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
    pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
    checkpoint_path: Optional[Union[str, Path]] = None,
    cache_dir: Optional[Union[str, Path]] = None,
    scriptable: Optional[bool] = None,
    exportable: Optional[bool] = None,
    no_jit: Optional[bool] = None,
    **kwargs: Any
) -> torch.nn.Module:
    """
    Create a model instance.

    Args:
        model_name: Name of model to instantiate
        pretrained: Load pretrained weights if True
        pretrained_cfg: Pretrained configuration override (dict or cfg name)
        pretrained_cfg_overlay: Dictionary of config overrides
        num_classes: Number of output classes (default: 1000)
        in_chans: Number of input image channels (default: 3)
        global_pool: Global pooling type override
        scriptable: Set layer config so model is jit scriptable
        exportable: Set layer config so model is traceable/ONNX exportable
        no_jit: Disable jit related set/reset of layer config
        checkpoint_path: Path to load checkpoint from instead of pretrained weights
        cache_dir: Cache directory for downloaded pretrained weights
        **kwargs: Model-specific arguments

    Returns:
        Instantiated model
    """
```

#### Usage Examples

```python
import timm

# Basic model creation
model = timm.create_model('resnet50', pretrained=True)

# Custom number of classes for fine-tuning
model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=10)

# Model for feature extraction
feature_model = timm.create_model('vit_base_patch16_224', pretrained=True, features_only=True)

# Model optimized for export
export_model = timm.create_model('resnet18', pretrained=True, scriptable=True, exportable=True)

# Load from custom checkpoint
model = timm.create_model('resnet50', checkpoint_path='/path/to/checkpoint.pth')
```
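
The `pretrained_cfg_overlay` argument is described above as a dictionary of config overrides. As a rough mental model, and not TIMM's actual implementation, an overlay can be pictured as a key-wise merge in which overlay entries take precedence. The `apply_overlay` helper and all values below are hypothetical and exist only for illustration.

```python
from typing import Any, Dict, Optional


def apply_overlay(
    base_cfg: Dict[str, Any],
    overlay: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Return a copy of base_cfg with overlay entries taking precedence.

    Hypothetical helper for illustration; not part of TIMM.
    """
    merged = dict(base_cfg)  # copy so the base config is left untouched
    merged.update(overlay or {})
    return merged


# Illustrative values only, not read from any real model config.
base_cfg = {"input_size": (3, 224, 224), "crop_pct": 0.875, "num_classes": 1000}
overlay = {"input_size": (3, 288, 288), "crop_pct": 1.0}

merged_cfg = apply_overlay(base_cfg, overlay)
print(merged_cfg)
```

Keys absent from the overlay keep their base values, which is why an overlay only needs to list the fields that differ.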

### Model Discovery

Functions to explore and filter the available model architectures and pretrained weights.

```python { .api }
def list_models(
    filter: str = '',
    module: str = '',
    pretrained: bool = False,
    exclude_filters: str = '',
    name_matches_cfg: bool = False,
    include_tags: bool = True
) -> list[str]:
    """
    List available models.

    Args:
        filter: Wildcard filter string to limit model names
        module: Specific module/architecture to limit results
        pretrained: Only models with pretrained weights if True
        exclude_filters: Exclude models matching these patterns
        name_matches_cfg: Only models where name matches config
        include_tags: Include model tags in results

    Returns:
        List of model names matching criteria
    """

def list_pretrained(filter: str = '') -> list[str]:
    """
    List models with pretrained weights available.

    Args:
        filter: Wildcard filter for model names

    Returns:
        List of model names with pretrained weights
    """

def list_modules() -> list[str]:
    """
    List available model modules/architectures.

    Returns:
        List of module names
    """
```

#### Usage Examples

```python
import timm

# List all models
all_models = timm.list_models()

# Filter models by architecture
resnet_models = timm.list_models('*resnet*')
vit_models = timm.list_models('vit_*')

# Only models with pretrained weights
pretrained_models = timm.list_models(pretrained=True)

# List specific architecture variants
efficientnet_pretrained = timm.list_pretrained('efficientnet*')

# Available model families
architectures = timm.list_modules()
```
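
The filter strings in the examples above are shell-style wildcards. The sketch below shows how such patterns select names, using Python's standard `fnmatch` module on a made-up list. The names and the `filter_names` helper are assumptions for illustration; the real `list_models` also deals with tags, modules, and pretrained checks.

```python
import fnmatch
from typing import List


def filter_names(names: List[str], include: str = "*", exclude: str = "") -> List[str]:
    """Keep names matching `include`, then drop any that match `exclude`.

    Hypothetical helper for illustration; not part of TIMM.
    """
    selected = fnmatch.filter(names, include)
    if exclude:
        selected = [n for n in selected if not fnmatch.fnmatch(n, exclude)]
    return sorted(selected)


# Made-up sample of model names.
names = [
    "resnet18",
    "resnet50",
    "seresnet50",
    "vit_base_patch16_224",
    "vit_small_patch16_224",
]

print(filter_names(names, "*resnet*"))
print(filter_names(names, "vit_*"))
print(filter_names(names, "*resnet*", exclude="se*"))
```

A leading `*` matters: `'*resnet*'` also matches names with a prefix such as `seresnet50`, while `'resnet*'` would not.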

### Model Validation

Utilities to validate model names and check availability of pretrained weights.

```python { .api }
def is_model(model_name: str) -> bool:
    """
    Check if model name is valid and available.

    Args:
        model_name: Name to check

    Returns:
        True if model exists, False otherwise
    """

def is_model_pretrained(model_name: str) -> bool:
    """
    Check if model has pretrained weights available.

    Args:
        model_name: Model name to check

    Returns:
        True if pretrained weights exist, False otherwise
    """

def model_entrypoint(model_name: str) -> Callable:
    """
    Get the entrypoint function for a model.

    Args:
        model_name: Name of model

    Returns:
        Model creation function
    """
```

### Model Configuration

Access and retrieve model configuration and metadata.

```python { .api }
def get_pretrained_cfg(model_name: str) -> dict:
    """
    Get pretrained configuration for model.

    Args:
        model_name: Name of model

    Returns:
        Dictionary containing model configuration including:
        - input_size: Expected input dimensions
        - mean: Normalization mean values
        - std: Normalization standard deviation values
        - num_classes: Number of output classes
        - pool_size: Global pooling output size
        - crop_pct: Center crop percentage
        - interpolation: Resize interpolation method
        - first_conv: Name of first convolutional layer
        - classifier: Name of classifier layer
    """

def get_pretrained_cfg_value(model_name: str, cfg_key: str):
    """
    Get specific configuration value for pretrained model.

    Args:
        model_name: Name of model
        cfg_key: Configuration key to retrieve

    Returns:
        Configuration value for specified key
    """
```

#### Usage Examples

```python
import timm

# Get complete model configuration.
# Recent timm releases return a PretrainedCfg object here; to_dict() gives
# the dictionary view described above.
cfg = timm.get_pretrained_cfg('resnet50').to_dict()
print(f"Input size: {cfg['input_size']}")
print(f"Mean: {cfg['mean']}")
print(f"Std: {cfg['std']}")

# Get specific configuration values
input_size = timm.get_pretrained_cfg_value('efficientnet_b0', 'input_size')
crop_pct = timm.get_pretrained_cfg_value('vit_base_patch16_224', 'crop_pct')

# Validate model availability
if timm.is_model('my_custom_model'):
    model = timm.create_model('my_custom_model')

# Check for pretrained weights
if timm.is_model_pretrained('resnet101'):
    model = timm.create_model('resnet101', pretrained=True)
```
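
The configuration fields listed above (`input_size`, `crop_pct`, `mean`, `std`) are typically what an evaluation pipeline needs. The sketch below assumes the common resize-then-center-crop scheme, in which the resize target is the crop size divided by `crop_pct`, followed by per-channel normalization. The `cfg` values (widely used ImageNet statistics) and both helper functions are illustrative assumptions, not output from TIMM.

```python
import math
from typing import Sequence, Tuple

# Illustrative configuration values; a real config would come from the library.
cfg = {
    "input_size": (3, 224, 224),
    "crop_pct": 0.875,
    "mean": (0.485, 0.456, 0.406),
    "std": (0.229, 0.224, 0.225),
}


def resize_target(input_size: Sequence[int], crop_pct: float) -> int:
    """Shorter-side resize target so that a center crop keeps `crop_pct` of it."""
    crop_size = input_size[-1]
    return math.floor(crop_size / crop_pct)


def normalize_pixel(
    rgb: Sequence[int], mean: Sequence[float], std: Sequence[float]
) -> Tuple[float, ...]:
    """Scale 8-bit channel values to [0, 1], then standardize per channel."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb, mean, std))


print(resize_target(cfg["input_size"], cfg["crop_pct"]))
print(normalize_pixel((124, 116, 104), cfg["mean"], cfg["std"]))
```

A pixel close to the dataset mean normalizes to values near zero in every channel, which is a quick sanity check that `mean` and `std` are being applied in the right order.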

### Advanced Model Creation

Advanced patterns for model customization and creation.

#### Model Factory Functions

```python { .api }
def create_model_from_pretrained(
    model_name: str,
    pretrained_cfg: dict = None,
    **model_kwargs
) -> torch.nn.Module:
    """
    Create model using specific pretrained configuration.

    Args:
        model_name: Name of model to create
        pretrained_cfg: Custom pretrained configuration
        **model_kwargs: Additional model arguments

    Returns:
        Configured model instance
    """
```

#### Custom Model Registration

```python { .api }
def register_model(fn: Callable = None, *, name: str = None) -> Callable:
    """
    Register a new model architecture.

    Args:
        fn: Model creation function
        name: Optional model name override

    Returns:
        Decorated function
    """
```

#### Usage Examples

```python
import timm
from timm.models import register_model

# Register custom model.
# MyCustomResNet is a placeholder for your own torch.nn.Module subclass.
# create_model may forward extra keyword arguments (such as pretrained_cfg
# settings) through **kwargs, so the class should accept or ignore them.
@register_model
def my_custom_resnet(pretrained=False, **kwargs):
    # Custom ResNet implementation
    model = MyCustomResNet(**kwargs)
    if pretrained:
        # Load custom pretrained weights
        pass
    return model

# Use registered model
custom_model = timm.create_model('my_custom_resnet', pretrained=True)
```
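
Registration, name validation, and entrypoint lookup fit together as a name-to-function registry. The following is a simplified sketch of that pattern and not TIMM's implementation; every name in it (`_ENTRYPOINTS`, `register`, `is_registered`, `entrypoint`, `create`, `tiny_net`) is hypothetical.

```python
from typing import Any, Callable, Dict

# Maps a model name to the function that builds it.
_ENTRYPOINTS: Dict[str, Callable[..., Any]] = {}


def register(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Decorator that records `fn` under its own function name."""
    _ENTRYPOINTS[fn.__name__] = fn
    return fn


def is_registered(name: str) -> bool:
    """Return True if a builder has been registered under `name`."""
    return name in _ENTRYPOINTS


def entrypoint(name: str) -> Callable[..., Any]:
    """Look up the builder for `name`, failing loudly on unknown names."""
    if not is_registered(name):
        raise KeyError(f"Unknown model: {name}")
    return _ENTRYPOINTS[name]


def create(name: str, pretrained: bool = False, **kwargs: Any) -> Any:
    """Resolve the builder by name and call it with the given options."""
    return entrypoint(name)(pretrained=pretrained, **kwargs)


@register
def tiny_net(pretrained: bool = False, **kwargs: Any) -> Dict[str, Any]:
    # Stand-in for a real module constructor: returns a plain description.
    return {"arch": "tiny_net", "pretrained": pretrained, **kwargs}


model = create("tiny_net", pretrained=True, num_classes=10)
print(model)
```

Because the decorator returns the function unchanged, a registered builder can still be called directly, and the string-based lookup is just an additional way to reach it.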

### Hugging Face Hub Integration

TIMM integrates with the Hugging Face Hub for loading models and configurations.

```python { .api }
def load_model_config_from_hf(model_id: str) -> dict:
    """
    Load model configuration from Hugging Face Hub.

    Args:
        model_id: Hugging Face model identifier

    Returns:
        Model configuration dictionary
    """

def load_state_dict_from_hf(model_id: str) -> dict:
    """
    Load model weights from Hugging Face Hub.

    Args:
        model_id: Hugging Face model identifier

    Returns:
        Model state dictionary
    """
```

#### Hub Model Loading Examples

```python
import timm

# Load model from Hugging Face Hub using hf-hub: prefix
# (the repository must be published in timm format)
model = timm.create_model('hf-hub:timm/resnet50.a1_in1k', pretrained=True)

# Load local model using local-dir: prefix
model = timm.create_model('local-dir:/path/to/model/folder', pretrained=True)

# Load specific model revision/branch
model = timm.create_model('hf-hub:timm/resnet50.a1_in1k@main', pretrained=True)
```
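
The model-name strings above combine a source prefix, an identifier, and an optional `@revision` suffix. The sketch below shows one way such a string could be split apart. It is for illustration only: `ModelSource` and `parse_model_source` are hypothetical, the repository name is made up, and TIMM's own parsing may differ.

```python
from typing import NamedTuple, Optional


class ModelSource(NamedTuple):
    source: str              # "hf-hub", "local-dir", or "builtin"
    identifier: str          # repository id, folder path, or plain model name
    revision: Optional[str]  # branch, tag, or commit for hub models


def parse_model_source(name: str) -> ModelSource:
    """Split a model name into source prefix, identifier, and revision.

    Hypothetical helper for illustration; not part of TIMM.
    """
    source, sep, rest = name.partition(":")
    if not sep or source not in ("hf-hub", "local-dir"):
        # No recognized prefix: treat the whole string as a built-in model name.
        return ModelSource("builtin", name, None)
    if source == "hf-hub" and "@" in rest:
        identifier, revision = rest.rsplit("@", 1)
        return ModelSource(source, identifier, revision)
    return ModelSource(source, rest, None)


print(parse_model_source("hf-hub:some-org/some-model@main"))
print(parse_model_source("local-dir:/path/to/model/folder"))
print(parse_model_source("resnet50"))
```

Only hub identifiers are checked for a revision here, so an `@` inside a local folder path is left untouched.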

## Model Architecture Categories

TIMM includes models from the following major categories:

### Vision Transformers
- **ViT**: Vision Transformer variants (Base, Large, Huge)
- **DeiT**: Data-efficient Image Transformers
- **BEiT**: Bidirectional Encoder representation from Image Transformers
- **Swin**: Swin Transformer hierarchical models
- **CaiT**: Class-Attention in Image Transformers
- **CrossViT**: Cross-Attention Multi-Scale Vision Transformer

### Convolutional Networks
- **ResNet**: ResNet and ResNeXt families
- **EfficientNet**: EfficientNet B0-B8 and V2 variants
- **ConvNeXt**: Modern ConvNet architectures
- **RegNet**: Designing Network Design Spaces
- **DenseNet**: Densely Connected Convolutional Networks
- **MobileNet**: MobileNetV3 and variants

### Hybrid Architectures
- **ConViT**: Convolutions meet Vision Transformers
- **LeViT**: Vision Transformer in ConvNet's Clothing
- **CoAtNet**: Convolution and Attention networks
- **MaxViT**: Multi-Axis Vision Transformer

### Specialized Models
- **CLIP**: Vision encoders from CLIP models
- **BEiT3**: Multimodal foundation models
- **EVA**: Enhanced Vision Transformer
- **InternViT**: Large-scale vision foundation models

### Advanced Features

#### NaFlexViT (Native Flexible Vision Transformers)
TIMM supports variable aspect ratio and resolution training/inference through NaFlexViT integration.

```python
import timm

# Enable NaFlexViT for supported models
model = timm.create_model('vit_base_patch16_224', pretrained=True, use_naflex=True)

# Models with RoPE support can be loaded in NaFlexViT mode
model = timm.create_model('eva_large_patch14_196', pretrained=True, use_naflex=True)
```

#### Forward Intermediates API
Extract intermediate features from models during the forward pass.

```python
import timm
import torch

# Enable intermediate feature extraction
model = timm.create_model('resnet50', pretrained=True)
x = torch.randn(1, 3, 224, 224)  # dummy input batch
final, features = model.forward_intermediates(x, indices=[1, 2, 3, 4])
```

## Types

```python { .api }
from pathlib import Path
from typing import Optional, Union, List, Dict, Callable, Any
import torch

# Model configuration types
PretrainedCfg = Dict[str, Any]
ModelCfg = Dict[str, Any]

# Model creation function signature
ModelEntrypoint = Callable[..., torch.nn.Module]

# Filter types for model listing
ModelFilter = Union[str, List[str]]
```