0
# Base Classes
1
2
Core abstract base classes that define common interfaces shared by all models, tokenizers, and configurations in the pytorch-transformers library. These classes provide essential functionality for loading, saving, and managing pre-trained components.
3
4
## Capabilities
5
6
### PreTrainedModel
7
8
Abstract base class for all transformer models, providing common functionality for loading and saving pre-trained weights, resizing token embeddings, pruning attention heads, and accessing the input/output embedding layers.
9
10
```python { .api }
11
class PreTrainedModel:
12
@classmethod
13
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
14
"""
15
Instantiate a pre-trained PyTorch model from a pre-trained model configuration.
16
17
Parameters:
18
- pretrained_model_name_or_path (str): Model name or local path
19
- config (PretrainedConfig, optional): Model configuration
20
- cache_dir (str, optional): Directory to cache downloaded files
21
- from_tf (bool, optional): Load from TensorFlow checkpoint
22
- force_download (bool, optional): Force re-download even if cached
23
- resume_download (bool, optional): Resume incomplete downloads
24
- proxies (dict, optional): HTTP proxy configuration
25
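- output_loading_info (bool, optional): If True, return a (model, loading_info) tuple, where loading_info is a dict listing missing keys, unexpected keys, and error messages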
26
- use_auth_token (str/bool, optional): Authentication token for private models
27
- revision (str, optional): Git branch/tag/commit to use
28
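- kwargs: Additional keyword arguments. If `config` is not provided, keys matching configuration attributes override the loaded configuration and the remaining keys are passed to the model constructor; if `config` is provided, all are passed to the model constructor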
29
30
Returns:
31
PreTrainedModel: Instance of the model class
32
"""
33
34
def save_pretrained(self, save_directory):
35
"""
36
Save model weights and configuration to a directory.
37
38
Parameters:
39
- save_directory (str): Directory to save model files
40
"""
41
42
def resize_token_embeddings(self, new_num_tokens=None):
43
"""
44
Resize token embeddings matrix of the model.
45
46
Parameters:
47
- new_num_tokens (int, optional): New vocabulary size
48
49
Returns:
50
torch.nn.Embedding: New embeddings matrix
51
"""
52
53
def prune_heads(self, heads_to_prune):
54
"""
55
Prune attention heads in the model.
56
57
Parameters:
58
- heads_to_prune (dict): Dictionary mapping layer to heads to prune
59
"""
60
61
def get_input_embeddings(self):
62
"""
63
Get the model's input embeddings.
64
65
Returns:
66
torch.nn.Module: Input embeddings layer
67
"""
68
69
def set_input_embeddings(self, value):
70
"""
71
Set the model's input embeddings.
72
73
Parameters:
74
- value (torch.nn.Module): New input embeddings layer
75
"""
76
77
def get_output_embeddings(self):
78
"""
79
Get the model's output embeddings.
80
81
Returns:
82
torch.nn.Module: Output embeddings layer
83
"""
84
85
def set_output_embeddings(self, new_embeddings):
86
"""
87
Set the model's output embeddings.
88
89
Parameters:
90
- new_embeddings (torch.nn.Module): New output embeddings layer
91
"""
92
```
93
94
**Usage Examples:**
95
96
```python
97
from pytorch_transformers import BertModel
98
import torch
99
100
# Load pre-trained model
101
model = BertModel.from_pretrained("bert-base-uncased")
102
103
# Save model
104
model.save_pretrained("./my-bert-model")
105
106
# Grow the embedding matrix, e.g. after adding 2 new tokens to the tokenizer
106
model.resize_token_embeddings(model.config.vocab_size + 2)
108
109
# Prune attention heads
110
heads_to_prune = {0: [0, 1], 1: [0]} # Prune heads 0,1 in layer 0 and head 0 in layer 1
111
model.prune_heads(heads_to_prune)
112
113
# Access embeddings
114
input_embeddings = model.get_input_embeddings()
115
print(f"Embedding dimensions: {input_embeddings.weight.shape}")
116
117
# Model inference
118
inputs = torch.randint(0, 1000, (1, 10)) # Random token IDs
119
outputs = model(inputs)
120
```
121
122
### PreTrainedTokenizer
123
124
Abstract base class for all tokenizers, providing a common tokenization interface and special-token handling.
125
126
```python { .api }
127
class PreTrainedTokenizer:
128
@classmethod
129
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
130
"""
131
Instantiate a pre-trained tokenizer from a model name or a local directory containing its vocabulary files.
132
133
Parameters:
134
- pretrained_model_name_or_path (str): Model name or local path
135
- cache_dir (str, optional): Directory to cache downloaded files
136
- force_download (bool, optional): Force re-download even if cached
137
- resume_download (bool, optional): Resume incomplete downloads
138
- proxies (dict, optional): HTTP proxy configuration
139
- use_auth_token (str/bool, optional): Authentication token for private models
140
- revision (str, optional): Git branch/tag/commit to use
141
- kwargs: Additional arguments passed to tokenizer constructor
142
143
Returns:
144
PreTrainedTokenizer: Instance of the tokenizer class
145
"""
146
147
def save_pretrained(self, save_directory):
148
"""
149
Save tokenizer vocabulary and configuration to a directory.
150
151
Parameters:
152
- save_directory (str): Directory to save tokenizer files
153
"""
154
155
def tokenize(self, text, **kwargs):
156
"""
157
Tokenize a string into a list of tokens.
158
159
Parameters:
160
- text (str): Input text to tokenize
161
- kwargs: Additional tokenization arguments
162
163
Returns:
164
List[str]: List of tokens
165
"""
166
167
def encode(self, text, text_pair=None, add_special_tokens=True, max_length=None, **kwargs):
168
"""
169
Encode text into token IDs.
170
171
Parameters:
172
- text (str): Primary input text
173
- text_pair (str, optional): Secondary input text for sentence pairs
174
- add_special_tokens (bool): Whether to add special tokens
175
- max_length (int, optional): Maximum sequence length
176
- kwargs: Additional encoding arguments
177
178
Returns:
179
List[int]: List of token IDs
180
"""
181
182
def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
183
"""
184
Decode token IDs back to text.
185
186
Parameters:
187
- token_ids (List[int]): Token IDs to decode
188
- skip_special_tokens (bool): Whether to remove special tokens
189
- clean_up_tokenization_spaces (bool): Whether to clean up spaces
190
191
Returns:
192
str: Decoded text
193
"""
194
195
def convert_tokens_to_ids(self, tokens):
196
"""
197
Convert tokens to token IDs.
198
199
Parameters:
200
- tokens (List[str] or str): Token(s) to convert
201
202
Returns:
203
List[int] or int: Token ID(s)
204
"""
205
206
def convert_ids_to_tokens(self, ids):
207
"""
208
Convert token IDs to tokens.
209
210
Parameters:
211
- ids (List[int] or int): Token ID(s) to convert
212
213
Returns:
214
List[str] or str: Token(s)
215
"""
216
217
def __call__(self, text, text_pair=None, **kwargs):
218
"""
219
Main tokenization method with tensor output support.
220
221
Parameters:
222
- text (str or List[str]): Input text(s)
223
- text_pair (str or List[str], optional): Pair text(s)
224
- return_tensors (str, optional): Type of tensors to return ('pt', 'tf', 'np')
225
- padding (bool/str, optional): Padding strategy
226
- truncation (bool/str, optional): Truncation strategy
227
- max_length (int, optional): Maximum sequence length
228
- kwargs: Additional arguments
229
230
Returns:
231
Dict: Dictionary containing input_ids, attention_mask, etc.
232
"""
233
```
234
235
**Special Token Properties:**
236
237
```python { .api }
238
# Special-token attributes defined on all tokenizers (None when a tokenizer does not use that token)
239
bos_token: str # Beginning of sequence token
240
eos_token: str # End of sequence token
241
unk_token: str # Unknown token
242
sep_token: str # Separator token
243
pad_token: str # Padding token
244
cls_token: str # Classification token
245
mask_token: str # Mask token for masked language modeling
246
247
# Special token IDs
248
bos_token_id: int
249
eos_token_id: int
250
unk_token_id: int
251
sep_token_id: int
252
pad_token_id: int
253
cls_token_id: int
254
mask_token_id: int
255
256
# Vocabulary size
257
vocab_size: int
258
```
259
260
**Usage Examples:**
261
262
```python
263
from pytorch_transformers import BertTokenizer
264
265
# Load tokenizer
266
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
267
268
# Basic tokenization
269
text = "Hello, how are you?"
270
tokens = tokenizer.tokenize(text)
271
print(f"Tokens: {tokens}")
272
273
# Encoding to IDs
274
token_ids = tokenizer.encode(text)
275
print(f"Token IDs: {token_ids}")
276
277
# Decoding back to text
278
decoded = tokenizer.decode(token_ids)
279
print(f"Decoded: {decoded}")
280
281
# Full preprocessing with tensors
282
inputs = tokenizer(
283
text,
284
return_tensors="pt",
285
padding=True,
286
truncation=True,
287
max_length=512
288
)
289
print(f"Input shape: {inputs['input_ids'].shape}")
290
291
# Access special tokens
292
print(f"CLS token: {tokenizer.cls_token}")
293
print(f"SEP token: {tokenizer.sep_token}")
294
print(f"PAD token ID: {tokenizer.pad_token_id}")
295
296
# Save tokenizer
297
tokenizer.save_pretrained("./my-tokenizer")
298
```
299
300
### PretrainedConfig
301
302
Base configuration class for all model configurations, containing model hyperparameters and architecture specifications.
303
304
```python { .api }
305
class PretrainedConfig:
306
@classmethod
307
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
308
"""
309
Instantiate a PretrainedConfig from a pre-trained model configuration.
310
311
Parameters:
312
- pretrained_model_name_or_path (str): Model name or local path
313
- cache_dir (str, optional): Directory to cache downloaded files
314
- force_download (bool, optional): Force re-download even if cached
315
- resume_download (bool, optional): Resume incomplete downloads
316
- proxies (dict, optional): HTTP proxy configuration
317
- use_auth_token (str/bool, optional): Authentication token for private models
318
- revision (str, optional): Git branch/tag/commit to use
319
- kwargs: Additional configuration parameters
320
321
Returns:
322
PretrainedConfig: Instance of the configuration class
323
"""
324
325
def save_pretrained(self, save_directory):
326
"""
327
Save configuration to a directory.
328
329
Parameters:
330
- save_directory (str): Directory to save configuration file
331
"""
332
333
def to_dict(self):
334
"""
335
Serialize configuration to a Python dictionary.
336
337
Returns:
338
Dict: Configuration as dictionary
339
"""
340
341
def to_json_string(self):
342
"""
343
Serialize configuration to a JSON string.
344
345
Returns:
346
str: Configuration as JSON string
347
"""
348
349
@classmethod
350
def from_dict(cls, config_dict, **kwargs):
351
"""
352
Construct configuration from a dictionary.
353
354
Parameters:
355
- config_dict (Dict): Configuration dictionary
356
- kwargs: Additional parameters
357
358
Returns:
359
PretrainedConfig: Configuration instance
360
"""
361
362
@classmethod
363
def from_json_file(cls, json_file):
364
"""
365
Construct configuration from a JSON file.
366
367
Parameters:
368
- json_file (str): Path to JSON configuration file
369
370
Returns:
371
PretrainedConfig: Configuration instance
372
"""
373
```
374
375
**Usage Examples:**
376
377
```python
378
from pytorch_transformers import BertConfig
379
380
# Load configuration
381
config = BertConfig.from_pretrained("bert-base-uncased")
382
383
# Access configuration parameters
384
print(f"Hidden size: {config.hidden_size}")
385
print(f"Number of layers: {config.num_hidden_layers}")
386
print(f"Number of attention heads: {config.num_attention_heads}")
387
388
# Modify configuration
389
config.num_labels = 3 # For classification with 3 classes
390
391
# Save configuration
392
config.save_pretrained("./my-config")
393
394
# Convert to dictionary/JSON
395
config_dict = config.to_dict()
396
config_json = config.to_json_string()
397
398
# Create from dictionary
399
custom_config = BertConfig.from_dict({
400
"hidden_size": 512,
401
"num_hidden_layers": 6,
402
"num_attention_heads": 8
403
})
404
```
405
406
### Model Utilities
407
408
Core utility classes and functions for model parameter management and weight manipulation.
409
410
#### Conv1D
411
412
A linear projection layer, named Conv1D for historical reasons and used in OpenAI GPT and GPT-2, whose weight matrix is stored transposed compared to a standard `nn.Linear` layer.
413
414
```python { .api }
415
class Conv1D(nn.Module):
416
def __init__(self, nf, nx):
417
"""
418
Conv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
419
Basically works like a Linear layer but the weights are transposed.
420
421
Parameters:
422
- nf (int): Size of output features
423
- nx (int): Size of input features
424
"""
425
426
def forward(self, x):
427
"""
428
Forward pass through the Conv1D layer.
429
430
Parameters:
431
- x (torch.Tensor): Input tensor
432
433
Returns:
434
torch.Tensor: Output tensor
435
"""
436
```
437
438
#### Layer Pruning
439
440
Utility functions for pruning model layers to remove attention heads or reduce model size.
441
442
```python { .api }
443
def prune_layer(layer, index, dim=None):
444
"""
445
Prune a Conv1D or nn.Linear layer to keep only entries in index.
446
Return the pruned layer as a new layer with requires_grad=True.
447
Used to remove heads.
448
449
Parameters:
450
- layer (nn.Module): Layer to prune (Conv1D or nn.Linear)
451
- index (torch.LongTensor): Indices of entries to keep
452
- dim (int, optional): Dimension along which to prune (default: 0 for Linear, 1 for Conv1D)
453
454
Returns:
455
nn.Module: New pruned layer
456
"""
457
```
458
459
**Usage Examples:**
460
461
```python
462
from pytorch_transformers import Conv1D, prune_layer
463
import torch
464
import torch.nn as nn
465
466
# Create a Conv1D layer
467
conv1d = Conv1D(768, 512) # 768 output features, 512 input features
468
input_tensor = torch.randn(32, 128, 512) # batch_size, seq_len, input_features
469
output = conv1d(input_tensor)
470
print(output.shape) # torch.Size([32, 128, 768])
471
472
# Prune a linear layer to keep only certain features
473
linear = nn.Linear(768, 12) # Original layer
474
indices_to_keep = torch.LongTensor([0, 2, 4, 6, 8, 10]) # Keep every other output feature (6 of 12)
474
pruned_linear = prune_layer(linear, indices_to_keep, dim=0) # dim=0 prunes output features (rows of weight)
476
print(f"Original: {linear.weight.shape}, Pruned: {pruned_linear.weight.shape}")
477
```
478
479
### Constants
480
481
File naming constants used throughout the library for consistent model serialization.
482
483
```python { .api }
484
# Model weight files
485
WEIGHTS_NAME: str = "pytorch_model.bin"
486
CONFIG_NAME: str = "config.json"
487
TF_WEIGHTS_NAME: str = "model.ckpt"
488
```
489
490
**Usage Examples:**
491
492
```python
493
from pytorch_transformers import WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME
494
import os
495
496
# Check for model files in a directory
497
model_dir = "./my-model"
498
weights_path = os.path.join(model_dir, WEIGHTS_NAME)
499
config_path = os.path.join(model_dir, CONFIG_NAME)
500
501
if os.path.exists(weights_path):
502
print(f"PyTorch weights found: {weights_path}")
503
if os.path.exists(config_path):
504
print(f"Config found: {config_path}")
505
506
# When loading TensorFlow weights
507
tf_weights_path = os.path.join(model_dir, TF_WEIGHTS_NAME + ".index")
508
if os.path.exists(tf_weights_path):
509
print(f"TensorFlow weights found: {tf_weights_path}")
510
```