or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

feature-extraction.md, generation.md, index.md, models.md, optimization.md, pipelines.md, tokenization.md, training.md

docs/models.md

0

# Models

1

2

The model system provides comprehensive model management, with automatic architecture selection and loading for 350+ architectures. It exposes consistent APIs across text, vision, audio, and multimodal domains while supporting the PyTorch, TensorFlow, and JAX frameworks.

3

4

## Capabilities

5

6

### Auto Model Classes

7

8

Automatic model selection based on model names or configurations, eliminating the need to know specific architecture classes.

9

10

```python { .api }

11

class AutoModel:

12

@classmethod

13

def from_pretrained(

14

cls,

15

pretrained_model_name_or_path: Union[str, os.PathLike],

16

*model_args,

17

config: PretrainedConfig = None,

18

cache_dir: Union[str, os.PathLike] = None,

19

ignore_mismatched_sizes: bool = False,

20

force_download: bool = False,

21

local_files_only: bool = False,

22

token: Union[bool, str] = None,

23

revision: str = "main",

24

use_safetensors: bool = None,

25

**kwargs

26

) -> PreTrainedModel:

27

"""

28

Load a pretrained model automatically detecting the architecture.

29

30

Args:

31

pretrained_model_name_or_path: Model name or local path

32

config: Model configuration (auto-detected if None)

33

cache_dir: Custom cache directory

34

ignore_mismatched_sizes: Ignore size mismatches when loading

35

force_download: Force fresh download

36

local_files_only: Only use local files

37

token: Hugging Face authentication token

38

revision: Model revision/branch

39

use_safetensors: Use safetensors format when available

40

41

Returns:

42

Loaded model instance

43

"""

44

```

45

46

### Task-Specific Auto Models

47

48

Pre-configured models for common tasks with appropriate heads and loss functions.

49

50

```python { .api }

51

class AutoModelForSequenceClassification:

52

@classmethod

53

def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> PreTrainedModel:

54

"""Load model for sequence classification tasks."""

55

56

class AutoModelForTokenClassification:

57

@classmethod

58

def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> PreTrainedModel:

59

"""Load model for token classification (NER, POS tagging)."""

60

61

class AutoModelForQuestionAnswering:

62

@classmethod

63

def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> PreTrainedModel:

64

"""Load model for extractive question answering."""

65

66

class AutoModelForMaskedLM:

67

@classmethod

68

def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> PreTrainedModel:

69

"""Load model for masked language modeling."""

70

71

class AutoModelForCausalLM:

72

@classmethod

73

def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> PreTrainedModel:

74

"""Load model for causal language modeling (text generation)."""

75

76

class AutoModelForSeq2SeqLM:

77

@classmethod

78

def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> PreTrainedModel:

79

"""Load model for sequence-to-sequence tasks."""

80

81

class AutoModelForImageClassification:

82

@classmethod

83

def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> PreTrainedModel:

84

"""Load model for image classification."""

85

86

class AutoModelForObjectDetection:

87

@classmethod

88

def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> PreTrainedModel:

89

"""Load model for object detection."""

90

```

91

92

Usage examples:

93

```python

94

# Text classification

95

model = AutoModelForSequenceClassification.from_pretrained(

96

"bert-base-uncased",

97

num_labels=3

98

)

99

100

# Text generation

101

model = AutoModelForCausalLM.from_pretrained("gpt2")

102

103

# Image classification

104

model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")

105

```

106

107

### Base Model Classes

108

109

Foundation classes that all specific model implementations inherit from.

110

111

```python { .api }

112

class PreTrainedModel:

113

"""Base class for all PyTorch models."""

114

115

def __init__(self, config: PretrainedConfig, *inputs, **kwargs)

116

117

@classmethod

118

def from_pretrained(

119

cls,

120

pretrained_model_name_or_path: Union[str, os.PathLike],

121

**kwargs

122

) -> 'PreTrainedModel':

123

"""Load pretrained model weights and configuration."""

124

125

def save_pretrained(

126

self,

127

save_directory: Union[str, os.PathLike],

128

is_main_process: bool = True,

129

state_dict: Dict[str, torch.Tensor] = None,

130

save_function: Callable = None,

131

push_to_hub: bool = False,

132

max_shard_size: Union[int, str] = "5GB",

133

safe_serialization: bool = True,

134

**kwargs

135

) -> None:

136

"""Save model weights and configuration."""

137

138

def push_to_hub(

139

self,

140

repo_id: str,

141

use_temp_dir: bool = None,

142

commit_message: str = None,

143

private: bool = None,

144

token: Union[bool, str] = None,

145

**kwargs

146

) -> str:

147

"""Upload model to Hugging Face Hub."""

148

149

def forward(self, **kwargs) -> Union[torch.Tensor, ModelOutput]:

150

"""Forward pass through the model."""

151

152

def generate(self, **kwargs) -> torch.Tensor:

153

"""Generate sequences (available on generative models)."""

154

155

def resize_token_embeddings(

156

self,

157

new_num_tokens: int = None

158

) -> torch.nn.Embedding:

159

"""Resize input token embeddings matrix."""

160

161

def get_input_embeddings(self) -> torch.nn.Module:

162

"""Get input embeddings layer."""

163

164

def set_input_embeddings(self, value: torch.nn.Module) -> None:

165

"""Set input embeddings layer."""

166

167

def tie_weights(self) -> None:

168

"""Tie input and output embeddings if specified in config."""

169

170

def gradient_checkpointing_enable(self) -> None:

171

"""Enable gradient checkpointing for training."""

172

173

def gradient_checkpointing_disable(self) -> None:

174

"""Disable gradient checkpointing."""

175

```

176

177

### TensorFlow Models

178

179

TensorFlow implementations of many model architectures, with Keras compatibility. Not every architecture has a TensorFlow implementation.

180

181

```python { .api }

182

class TFPreTrainedModel:

183

"""Base class for all TensorFlow models."""

184

185

@classmethod

186

def from_pretrained(

187

cls,

188

pretrained_model_name_or_path: Union[str, os.PathLike],

189

**kwargs

190

) -> 'TFPreTrainedModel':

191

"""Load pretrained TensorFlow model."""

192

193

def save_pretrained(

194

self,

195

save_directory: Union[str, os.PathLike],

196

**kwargs

197

) -> None:

198

"""Save TensorFlow model."""

199

200

def call(self, **kwargs) -> Union[tf.Tensor, TFModelOutput]:

201

"""Forward pass through TensorFlow model."""

202

203

# Task-specific TF models

204

class TFAutoModel:

205

@classmethod

206

def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> TFPreTrainedModel

207

208

class TFAutoModelForSequenceClassification:

209

@classmethod

210

def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> TFPreTrainedModel

211

212

class TFAutoModelForCausalLM:

213

@classmethod

214

def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> TFPreTrainedModel

215

```

216

217

### Flax/JAX Models

218

219

JAX implementations with Flax for high-performance training and inference.

220

221

```python { .api }

222

class FlaxPreTrainedModel:

223

"""Base class for all Flax/JAX models."""

224

225

@classmethod

226

def from_pretrained(

227

cls,

228

pretrained_model_name_or_path: Union[str, os.PathLike],

229

**kwargs

230

) -> 'FlaxPreTrainedModel':

231

"""Load pretrained Flax model."""

232

233

def save_pretrained(

234

self,

235

save_directory: Union[str, os.PathLike],

236

**kwargs

237

) -> None:

238

"""Save Flax model."""

239

240

def __call__(self, **kwargs) -> Union[jnp.ndarray, FlaxModelOutput]:

241

"""Forward pass through Flax model."""

242

243

# Task-specific Flax models

244

class FlaxAutoModel:

245

@classmethod

246

def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> FlaxPreTrainedModel

247

248

class FlaxAutoModelForCausalLM:

249

@classmethod

250

def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> FlaxPreTrainedModel

251

```

252

253

### Model Configurations

254

255

Configuration classes that define model architectures and hyperparameters.

256

257

```python { .api }

258

class AutoConfig:

259

@classmethod

260

def from_pretrained(

261

cls,

262

pretrained_model_name_or_path: Union[str, os.PathLike],

263

**kwargs

264

) -> PretrainedConfig:

265

"""Load model configuration automatically."""

266

267

class PretrainedConfig:

268

"""Base configuration class for all models."""

269

270

def __init__(self, **kwargs)

271

272

@classmethod

273

def from_pretrained(

274

cls,

275

pretrained_model_name_or_path: Union[str, os.PathLike],

276

**kwargs

277

) -> 'PretrainedConfig':

278

"""Load configuration from pretrained model."""

279

280

def save_pretrained(

281

self,

282

save_directory: Union[str, os.PathLike],

283

push_to_hub: bool = False,

284

**kwargs

285

) -> None:

286

"""Save configuration to directory."""

287

288

def to_dict(self) -> Dict[str, Any]:

289

"""Convert configuration to dictionary."""

290

291

def to_json_file(self, json_file_path: Union[str, os.PathLike]) -> None:

292

"""Save configuration to JSON file."""

293

```

294

295

### Popular Model Architectures

296

297

#### BERT Family

298

```python { .api }

299

class BertModel(PreTrainedModel):

300

"""BERT model for encoding tasks."""

301

302

class BertForSequenceClassification(PreTrainedModel):

303

"""BERT model with sequence classification head."""

304

305

class BertForTokenClassification(PreTrainedModel):

306

"""BERT model with token classification head."""

307

308

class BertForQuestionAnswering(PreTrainedModel):

309

"""BERT model with question answering head."""

310

311

class BertForMaskedLM(PreTrainedModel):

312

"""BERT model with masked language modeling head."""

313

```

314

315

#### GPT Family

316

```python { .api }

317

class GPT2Model(PreTrainedModel):

318

"""GPT-2 model for generation tasks."""

319

320

class GPT2LMHeadModel(PreTrainedModel):

321

"""GPT-2 model with language modeling head."""

322

323

class GPTNeoModel(PreTrainedModel):

324

"""GPT-Neo model architecture."""

325

326

class GPTNeoXModel(PreTrainedModel):

327

"""GPT-NeoX model architecture."""

328

329

class GPTJModel(PreTrainedModel):

330

"""GPT-J model architecture."""

331

```

332

333

#### T5 Family

334

```python { .api }

335

class T5Model(PreTrainedModel):

336

"""T5 encoder-decoder model."""

337

338

class T5ForConditionalGeneration(PreTrainedModel):

339

"""T5 model with conditional generation head."""

340

341

class T5EncoderModel(PreTrainedModel):

342

"""T5 encoder-only model."""

343

```

344

345

#### Vision Models

346

```python { .api }

347

class ViTModel(PreTrainedModel):

348

"""Vision Transformer model."""

349

350

class ViTForImageClassification(PreTrainedModel):

351

"""ViT model with image classification head."""

352

353

class DetrModel(PreTrainedModel):

354

"""DETR object detection model."""

355

356

class DetrForObjectDetection(PreTrainedModel):

357

"""DETR model with object detection head."""

358

```

359

360

#### Multimodal Models

361

```python { .api }

362

class CLIPModel(PreTrainedModel):

363

"""CLIP vision-language model."""

364

365

class CLIPTextModel(PreTrainedModel):

366

"""CLIP text encoder."""

367

368

class CLIPVisionModel(PreTrainedModel):

369

"""CLIP vision encoder."""

370

371

class BlipModel(PreTrainedModel):

372

"""BLIP multimodal model."""

373

374

class BlipForConditionalGeneration(PreTrainedModel):

375

"""BLIP model with conditional generation."""

376

```

377

378

## Model Output Types

379

380

Standard output formats for different model types:

381

382

```python { .api }

383

class BaseModelOutput:

384

"""Base output type for encoder models."""

385

last_hidden_state: torch.Tensor

386

hidden_states: Optional[Tuple[torch.Tensor]] = None

387

attentions: Optional[Tuple[torch.Tensor]] = None

388

389

class BaseModelOutputWithPooling:

390

"""Base output with pooling for classification models."""

391

last_hidden_state: torch.Tensor

392

pooler_output: torch.Tensor

393

hidden_states: Optional[Tuple[torch.Tensor]] = None

394

attentions: Optional[Tuple[torch.Tensor]] = None

395

396

class CausalLMOutput:

397

"""Output for causal language models."""

398

loss: Optional[torch.Tensor] = None

399

logits: torch.Tensor

400

hidden_states: Optional[Tuple[torch.Tensor]] = None

401

attentions: Optional[Tuple[torch.Tensor]] = None

402

403

class SequenceClassifierOutput:

404

"""Output for sequence classification models."""

405

loss: Optional[torch.Tensor] = None

406

logits: torch.Tensor

407

hidden_states: Optional[Tuple[torch.Tensor]] = None

408

attentions: Optional[Tuple[torch.Tensor]] = None

409

410

class TokenClassifierOutput:

411

"""Output for token classification models."""

412

loss: Optional[torch.Tensor] = None

413

logits: torch.Tensor

414

hidden_states: Optional[Tuple[torch.Tensor]] = None

415

attentions: Optional[Tuple[torch.Tensor]] = None

416

417

class QuestionAnsweringModelOutput:

418

"""Output for question answering models."""

419

loss: Optional[torch.Tensor] = None

420

start_logits: torch.Tensor

421

end_logits: torch.Tensor

422

hidden_states: Optional[Tuple[torch.Tensor]] = None

423

attentions: Optional[Tuple[torch.Tensor]] = None

424

```

425

426

## Loading and Saving Patterns

427

428

Common patterns for working with models:

429

430

```python

431

# Load model with custom configuration

432

config = AutoConfig.from_pretrained("bert-base-uncased")

433

config.num_labels = 3

434

model = AutoModelForSequenceClassification.from_pretrained(

435

"bert-base-uncased",

436

config=config

437

)

438

439

# Load model with custom dtype and device

440

model = AutoModelForCausalLM.from_pretrained(

441

"gpt2",

442

torch_dtype=torch.float16,

443

device_map="auto"

444

)

445

446

# Save model locally

447

model.save_pretrained("./my-model")

448

449

# Upload to Hub

450

model.push_to_hub("username/my-model", private=True)

451

452

# Load from local directory

453

model = AutoModel.from_pretrained("./my-model")

454

```