# Auto Classes

Factory classes that provide automatic model, tokenizer, and configuration selection based on model name patterns. These classes eliminate the need to manually specify an architecture-specific class, making it easy to switch between different transformer models.

## Capabilities

### AutoTokenizer

Automatically selects and instantiates the appropriate tokenizer class based on the model name or path. Supports BERT, GPT-2, OpenAI GPT, Transformer-XL, XLNet, XLM, RoBERTa, and DistilBERT tokenizers.

```python { .api }
class AutoTokenizer:
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        """
        Instantiate the appropriate tokenizer class from a pre-trained model.

        Parameters:
        - pretrained_model_name_or_path (str): Model name or local path
        - cache_dir (str, optional): Directory to cache downloaded files
        - force_download (bool, optional): Force re-download even if cached
        - resume_download (bool, optional): Resume incomplete downloads
        - proxies (dict, optional): HTTP proxy configuration
        - use_auth_token (str/bool, optional): Authentication token for private models

        Returns:
        PreTrainedTokenizer: Instance of the appropriate tokenizer class
        """
```

**Usage Examples:**

```python
from pytorch_transformers import AutoTokenizer

# Load BERT tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Load GPT-2 tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Load from a local directory
tokenizer = AutoTokenizer.from_pretrained("./my-model")

# With a custom cache directory
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", cache_dir="./cache")
```

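Conceptually, every tokenizer maps text to integer ids and back. The toy round trip below uses a hypothetical whitespace vocabulary purely for illustration (real tokenizers here use subword algorithms such as WordPiece or BPE), but it shows the encode/decode contract all of them share:

```python
# Toy illustration of the encode/decode contract (hypothetical vocabulary;
# real tokenizers use subword algorithms such as WordPiece or BPE).
vocab = {"hello": 0, "how": 1, "are": 2, "you": 3, "[UNK]": 4}
inv_vocab = {i: tok for tok, i in vocab.items()}

def encode(text):
    # Map each whitespace-separated token to its id, falling back to [UNK].
    return [vocab.get(tok, vocab["[UNK]"]) for tok in text.lower().split()]

def decode(ids):
    # Map ids back to tokens and re-join with spaces.
    return " ".join(inv_vocab[i] for i in ids)

ids = encode("Hello how are you")
print(ids)          # [0, 1, 2, 3]
print(decode(ids))  # hello how are you
```
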
### AutoConfig

Automatically selects and loads the appropriate configuration class based on the model name or path. Configurations contain model hyperparameters and architecture specifications.

```python { .api }
class AutoConfig:
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Instantiate the appropriate configuration class from a pre-trained model.

        Parameters:
        - pretrained_model_name_or_path (str): Model name or local path
        - cache_dir (str, optional): Directory to cache downloaded files
        - force_download (bool, optional): Force re-download even if cached
        - resume_download (bool, optional): Resume incomplete downloads
        - proxies (dict, optional): HTTP proxy configuration
        - use_auth_token (str/bool, optional): Authentication token for private models

        Returns:
        PretrainedConfig: Instance of the appropriate configuration class
        """
```

**Usage Examples:**

```python
from pytorch_transformers import AutoConfig

# Load configuration
config = AutoConfig.from_pretrained("bert-base-uncased")

# Access configuration attributes
print(config.hidden_size)
print(config.num_attention_heads)
print(config.num_hidden_layers)
```

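Under the hood, a configuration is serialized as a plain `config.json` file of hyperparameters inside the model directory. The stdlib-only sketch below writes and reads such a file; the key names mirror the attributes shown above, but the values are merely illustrative, not those of any particular checkpoint:

```python
import json
import os
import tempfile

# Illustrative hyperparameters; a real config.json for a BERT-style model
# contains keys like these (values here are common defaults, used as examples).
config = {"hidden_size": 768, "num_attention_heads": 12, "num_hidden_layers": 12}

with tempfile.TemporaryDirectory() as model_dir:
    path = os.path.join(model_dir, "config.json")
    with open(path, "w") as f:
        json.dump(config, f)
    # Loading is symmetric: read the JSON back into a dict of hyperparameters.
    with open(path) as f:
        loaded = json.load(f)

print(loaded["hidden_size"])  # 768
```
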
### AutoModel

Automatically loads the base model class (without task-specific heads) for the specified architecture. Returns models suitable for feature extraction and embedding generation.

```python { .api }
class AutoModel:
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Instantiate the appropriate base model class from a pre-trained model.

        Parameters:
        - pretrained_model_name_or_path (str): Model name or local path
        - config (PretrainedConfig, optional): Model configuration
        - cache_dir (str, optional): Directory to cache downloaded files
        - from_tf (bool, optional): Load from TensorFlow checkpoint
        - force_download (bool, optional): Force re-download even if cached
        - resume_download (bool, optional): Resume incomplete downloads
        - proxies (dict, optional): HTTP proxy configuration
        - output_loading_info (bool, optional): Return loading info dict
        - use_auth_token (str/bool, optional): Authentication token for private models

        Returns:
        PreTrainedModel: Instance of the appropriate base model class
        """
```

### AutoModelWithLMHead

Automatically loads models with language modeling heads for text generation and language modeling tasks.

```python { .api }
class AutoModelWithLMHead:
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Instantiate the appropriate language modeling model from a pre-trained model.

        Parameters:
        - pretrained_model_name_or_path (str): Model name or local path
        - config (PretrainedConfig, optional): Model configuration
        - cache_dir (str, optional): Directory to cache downloaded files
        - from_tf (bool, optional): Load from TensorFlow checkpoint
        - force_download (bool, optional): Force re-download even if cached
        - resume_download (bool, optional): Resume incomplete downloads
        - proxies (dict, optional): HTTP proxy configuration
        - output_loading_info (bool, optional): Return loading info dict
        - use_auth_token (str/bool, optional): Authentication token for private models

        Returns:
        PreTrainedModel: Instance of the appropriate LM model class
        """
```

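An LM head produces, for each position, one score per vocabulary item; greedy generation appends the highest-scoring token for the last position at each step. A framework-free sketch of that selection (the scores below are hypothetical, not a real model call):

```python
def greedy_next_token(last_position_scores):
    # Pick the vocabulary index with the highest score (greedy decoding).
    return max(range(len(last_position_scores)),
               key=lambda i: last_position_scores[i])

# Hypothetical scores over a 5-token vocabulary for the last input position.
logits = [0.1, 2.3, -1.0, 0.7, 1.9]
print(greedy_next_token(logits))  # 1
```
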
### AutoModelForSequenceClassification

Automatically loads models with sequence classification heads for tasks like sentiment analysis, text classification, and natural language inference.

```python { .api }
class AutoModelForSequenceClassification:
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Instantiate the appropriate sequence classification model from a pre-trained model.

        Parameters:
        - pretrained_model_name_or_path (str): Model name or local path
        - config (PretrainedConfig, optional): Model configuration
        - num_labels (int, optional): Number of classification labels
        - cache_dir (str, optional): Directory to cache downloaded files
        - from_tf (bool, optional): Load from TensorFlow checkpoint
        - force_download (bool, optional): Force re-download even if cached
        - resume_download (bool, optional): Resume incomplete downloads
        - proxies (dict, optional): HTTP proxy configuration
        - output_loading_info (bool, optional): Return loading info dict
        - use_auth_token (str/bool, optional): Authentication token for private models

        Returns:
        PreTrainedModel: Instance of the appropriate sequence classification model
        """
```

### AutoModelForQuestionAnswering

Automatically loads models with question answering heads for extractive question answering tasks like SQuAD.

```python { .api }
class AutoModelForQuestionAnswering:
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Instantiate the appropriate question answering model from a pre-trained model.

        Parameters:
        - pretrained_model_name_or_path (str): Model name or local path
        - config (PretrainedConfig, optional): Model configuration
        - cache_dir (str, optional): Directory to cache downloaded files
        - from_tf (bool, optional): Load from TensorFlow checkpoint
        - force_download (bool, optional): Force re-download even if cached
        - resume_download (bool, optional): Resume incomplete downloads
        - proxies (dict, optional): HTTP proxy configuration
        - output_loading_info (bool, optional): Return loading info dict
        - use_auth_token (str/bool, optional): Authentication token for private models

        Returns:
        PreTrainedModel: Instance of the appropriate QA model class
        """
```

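Extractive QA models score every input token twice: once as a candidate answer start and once as a candidate end. The post-processing sketch below (with hypothetical scores) picks the valid span, end at or after start, that maximizes the summed score:

```python
def best_span(start_scores, end_scores):
    # Choose (start, end) with end >= start maximizing start + end score.
    best, best_score = (0, 0), float("-inf")
    for i, s in enumerate(start_scores):
        for j in range(i, len(end_scores)):
            if s + end_scores[j] > best_score:
                best_score = s + end_scores[j]
                best = (i, j)
    return best

# Hypothetical token-level scores for a 5-token passage.
start = [0.1, 3.0, 0.2, 0.1, 0.0]
end = [0.0, 0.5, 0.1, 2.5, 0.2]
print(best_span(start, end))  # (1, 3)
```
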
**Usage Examples:**

```python
import torch
from pytorch_transformers import (
    AutoModel,
    AutoModelWithLMHead,
    AutoModelForSequenceClassification,
    AutoModelForQuestionAnswering,
    AutoTokenizer
)

# Load base model for feature extraction
model = AutoModel.from_pretrained("bert-base-uncased")

# Load language model for text generation
lm_model = AutoModelWithLMHead.from_pretrained("gpt2")

# Load sequence classifier
classifier = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=2
)

# Load question answering model
qa_model = AutoModelForQuestionAnswering.from_pretrained("bert-base-uncased")

# Use with a tokenizer for a complete pipeline
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
input_ids = torch.tensor([tokenizer.encode("Hello, how are you?")])
outputs = model(input_ids)
last_hidden_states = outputs[0]  # models return tuples; the first element is the hidden states
```

## Supported Model Names

The Auto classes support the following pre-trained model names:

**BERT Models:**
- `bert-base-uncased`, `bert-large-uncased`
- `bert-base-cased`, `bert-large-cased`
- `bert-base-multilingual-uncased`, `bert-base-multilingual-cased`
- `bert-base-chinese`

**GPT-2 Models:**
- `gpt2`, `gpt2-medium`, `gpt2-large`, `gpt2-xl`

**Other Models:**
- `openai-gpt`
- `transfo-xl-wt103`
- `xlnet-base-cased`, `xlnet-large-cased`
- `xlm-mlm-en-2048`, `xlm-mlm-100-1280`
- `roberta-base`, `roberta-large`
- `distilbert-base-uncased`, `distilbert-base-cased`
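
Selection works by substring matching on the model name or path, with more specific patterns checked before more general ones: `distilbert` and `roberta` must be tested before `bert`, since those names also contain `bert`. A simplified sketch of that dispatch logic (class names shown as strings for illustration):

```python
# Order matters: "distilbert-base-uncased" and "roberta-base" both contain
# the substring "bert", so the more specific patterns come first.
PATTERNS = [
    ("distilbert", "DistilBertModel"),
    ("roberta", "RobertaModel"),
    ("bert", "BertModel"),
    ("openai-gpt", "OpenAIGPTModel"),
    ("gpt2", "GPT2Model"),
    ("transfo-xl", "TransfoXLModel"),
    ("xlnet", "XLNetModel"),
    ("xlm", "XLMModel"),
]

def resolve_model_class(name_or_path):
    # Return the first architecture whose pattern appears in the name.
    for pattern, cls_name in PATTERNS:
        if pattern in name_or_path:
            return cls_name
    raise ValueError("Unrecognized model name: %s" % name_or_path)

print(resolve_model_class("distilbert-base-uncased"))       # DistilBertModel
print(resolve_model_class("bert-base-multilingual-cased"))  # BertModel
```
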