# BERT Models

The complete BERT model family: a configuration class, the base model, and task-specific variants for bidirectional language understanding tasks such as sequence classification, question answering, token classification, and masked language modeling.

## Capabilities

### Configuration

`BertConfig` stores the model's configuration parameters, including architecture dimensions, layer counts, and training hyperparameters such as dropout rates and the weight initialization range.

```python { .api }
class BertConfig:
    def __init__(
        self,
        vocab_size_or_config_json_file,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02
    ):
        """
        Initialize BERT configuration.

        Args:
            vocab_size_or_config_json_file (int or str): Vocabulary size, or path to a
                config JSON file (required)
            hidden_size (int): Dimension of the encoder layers and the pooler layer
            num_hidden_layers (int): Number of Transformer encoder layers
            num_attention_heads (int): Number of attention heads per layer
                (hidden_size must be divisible by it)
            intermediate_size (int): Dimension of the feed-forward ("intermediate") layer
            hidden_act (str): Activation function ('gelu', 'relu', or 'swish')
            hidden_dropout_prob (float): Dropout probability for the fully connected
                layers in the embeddings, encoder, and pooler
            attention_probs_dropout_prob (float): Dropout probability for attention weights
            max_position_embeddings (int): Maximum sequence length the model can handle
            type_vocab_size (int): Number of token types (segments) in token_type_ids
            initializer_range (float): Standard deviation of the normal distribution
                used to initialize weight matrices
        """

    @classmethod
    def from_dict(cls, json_object):
        """Create configuration from a Python dictionary of parameters."""

    @classmethod
    def from_json_file(cls, json_file):
        """Create configuration from a JSON file of parameters."""

    def to_dict(self):
        """Serialize this configuration to a Python dictionary."""

    def to_json_string(self):
        """Serialize this configuration to a JSON string."""
```
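
The following worked example shows how these configuration values determine model size. It counts the trainable parameters of a `BertModel`: embeddings, encoder layers, and pooler, with no task head. This is a hand-written estimate that assumes the standard BERT layout: a bias on every linear layer, one LayerNorm after the embeddings, and two LayerNorms per encoder layer. `count_bert_parameters` is a hypothetical helper and is not part of the library.

```python
def count_bert_parameters(vocab_size, hidden_size=768, num_hidden_layers=12,
                          num_attention_heads=12, intermediate_size=3072,
                          max_position_embeddings=512, type_vocab_size=2):
    """Hypothetical helper: estimate the parameter count of a BertModel.

    Assumes biases on every linear layer and LayerNorm with weight and bias.
    """
    if hidden_size % num_attention_heads != 0:
        raise ValueError("hidden_size must be divisible by num_attention_heads")

    def linear(n_in, n_out):
        # weight matrix plus bias vector
        return n_in * n_out + n_out

    layer_norm = 2 * hidden_size  # LayerNorm weight (gain) and bias

    # Word, position, and token-type embedding tables, followed by a LayerNorm
    embeddings = (vocab_size + max_position_embeddings + type_vocab_size) * hidden_size + layer_norm

    # One encoder layer: Q/K/V projections, attention output, feed-forward, two LayerNorms
    per_layer = (
        3 * linear(hidden_size, hidden_size)
        + linear(hidden_size, hidden_size) + layer_norm
        + linear(hidden_size, intermediate_size)
        + linear(intermediate_size, hidden_size) + layer_norm
    )

    # Pooler: one dense layer applied to the [CLS] hidden state
    pooler = linear(hidden_size, hidden_size)

    return embeddings + num_hidden_layers * per_layer + pooler


# BERT-Base defaults with the 30,522-token uncased vocabulary
print(count_bert_parameters(30522))  # 109482240, roughly 110M
print(768 // 12)  # per-head dimension: 64
```

Most of the parameters sit in the encoder layers. The word embedding table alone, however, accounts for about 23M of them.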

### Base Model

The core BERT Transformer model, which outputs raw hidden states without any task-specific head.

```python { .api }
class BertModel:
    def __init__(self, config, output_attentions=False):
        """
        Initialize BERT base model.

        Args:
            config (BertConfig): Model configuration
            output_attentions (bool): Whether to also return the attention weights
        """

    def forward(
        self,
        input_ids,
        token_type_ids=None,
        attention_mask=None,
        output_all_encoded_layers=True
    ):
        """
        Forward pass through BERT model.

        Args:
            input_ids (torch.Tensor): Token IDs of shape [batch_size, seq_len]
            token_type_ids (torch.Tensor, optional): Segment IDs of shape [batch_size, seq_len]:
                0 for sentence A tokens, 1 for sentence B tokens (defaults to all zeros)
            attention_mask (torch.Tensor, optional): Mask of shape [batch_size, seq_len]:
                1 for real tokens, 0 for padding (defaults to all ones)
            output_all_encoded_layers (bool): If True, return the hidden states of every
                layer; if False, return only those of the last layer

        Returns:
            tuple: (encoded_layers, pooled_output) where:
                - encoded_layers (list or torch.Tensor): one [batch_size, seq_len, hidden_size]
                  tensor per layer, or a single such tensor for the last layer when
                  output_all_encoded_layers=False
                - pooled_output (torch.Tensor): [batch_size, hidden_size] representation
                  derived from the first ([CLS]) token, used for classification
        """

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        cache_dir=None,
        output_attentions=False,
        **kwargs
    ):
        """Load a pre-trained model from a shortcut name such as 'bert-base-uncased' or a local path."""
```
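
The `token_type_ids` and `attention_mask` arguments follow a fixed layout. `[CLS]`, sentence A, and its `[SEP]` get segment 0. Sentence B and its `[SEP]` get segment 1. Padding positions get mask 0. The sketch below builds these lists for a padded batch in plain Python. It makes three assumptions:

- The inputs are already WordPiece IDs.
- The bert-base-uncased IDs 101 (`[CLS]`), 102 (`[SEP]`), and 0 (`[PAD]`) apply.
- `encode_batch` is a hypothetical helper, not a library function.

Wrap each returned list in `torch.tensor(...)` before passing it to the model.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # assumed bert-base-uncased vocabulary IDs


def encode_batch(pairs, max_len):
    """Hypothetical helper: build padded BertModel inputs.

    pairs: list of (ids_a, ids_b) tuples; ids_b may be None for single sentences.
    Returns (input_ids, token_type_ids, attention_mask), each [batch_size][max_len].
    """
    input_ids, token_type_ids, attention_mask = [], [], []
    for ids_a, ids_b in pairs:
        # [CLS] A [SEP] -> segment 0
        ids = [CLS_ID] + list(ids_a) + [SEP_ID]
        types = [0] * len(ids)
        if ids_b is not None:
            # B [SEP] -> segment 1
            ids += list(ids_b) + [SEP_ID]
            types += [1] * (len(ids_b) + 1)
        if len(ids) > max_len:
            raise ValueError("sequence longer than max_len; truncate first")
        mask = [1] * len(ids)  # 1 marks real tokens
        pad = max_len - len(ids)
        input_ids.append(ids + [PAD_ID] * pad)
        token_type_ids.append(types + [0] * pad)
        attention_mask.append(mask + [0] * pad)  # 0 marks padding
    return input_ids, token_type_ids, attention_mask


ids, types, mask = encode_batch([([2054, 2003], [2023]), ([2023, 2003], None)], max_len=7)
print(ids)    # [[101, 2054, 2003, 102, 2023, 102, 0], [101, 2023, 2003, 102, 0, 0, 0]]
print(types)  # [[0, 0, 0, 0, 1, 1, 0], [0, 0, 0, 0, 0, 0, 0]]
print(mask)   # [[1, 1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 0, 0, 0]]
```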

### Pre-training Model

BERT model with both the masked language modeling and next sentence prediction heads, used for pre-training.

```python { .api }
class BertForPreTraining:
    def __init__(self, config):
        """
        Initialize BERT for pre-training.

        Args:
            config (BertConfig): Model configuration
        """

    def forward(
        self,
        input_ids,
        token_type_ids=None,
        attention_mask=None,
        masked_lm_labels=None,
        next_sentence_label=None
    ):
        """
        Forward pass with pre-training heads.

        Args:
            input_ids (torch.Tensor): Token IDs
            token_type_ids (torch.Tensor, optional): Segment IDs
            attention_mask (torch.Tensor, optional): Attention mask
            masked_lm_labels (torch.Tensor, optional): MLM labels of shape [batch_size, seq_len];
                positions set to -1 are ignored by the loss
            next_sentence_label (torch.Tensor, optional): NSP labels of shape [batch_size]:
                0 if sentence B continues sentence A, 1 if sentence B is random

        Returns:
            The total (MLM + NSP) loss if both masked_lm_labels and next_sentence_label
            are provided; otherwise a tuple (prediction_scores, seq_relationship_score)
            of shapes [batch_size, seq_len, vocab_size] and [batch_size, 2]
        """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, **kwargs):
        """Load pre-trained model."""
```
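
`masked_lm_labels` is typically produced by the masking procedure described in the BERT paper:

1. Select about 15% of the non-special tokens.
2. Replace 80% of the selected tokens with `[MASK]`, 10% with a random token, and leave 10% unchanged.
3. Give every unselected position the label -1 so the loss ignores it.

The sketch below illustrates that procedure in plain Python; it is not the library's own data pipeline. `mask_tokens` is hypothetical. The IDs 101/102/103 for `[CLS]`/`[SEP]`/`[MASK]` and the vocabulary size of 30522 are bert-base-uncased assumptions.

```python
import random

CLS_ID, SEP_ID, MASK_ID = 101, 102, 103  # assumed bert-base-uncased IDs
VOCAB_SIZE = 30522
IGNORE = -1  # label value the MLM loss skips


def mask_tokens(input_ids, rng, mask_prob=0.15):
    """Hypothetical helper: return (masked_ids, masked_lm_labels) for one sequence."""
    masked, labels = list(input_ids), [IGNORE] * len(input_ids)
    # never mask the special tokens
    candidates = [i for i, t in enumerate(input_ids) if t not in (CLS_ID, SEP_ID)]
    num_to_mask = min(len(candidates), max(1, round(len(candidates) * mask_prob)))
    for i in rng.sample(candidates, num_to_mask):
        labels[i] = input_ids[i]  # the model must recover the original token
        roll = rng.random()
        if roll < 0.8:
            masked[i] = MASK_ID  # 80%: replace with [MASK]
        elif roll < 0.9:
            masked[i] = rng.randrange(VOCAB_SIZE)  # 10%: random token
        # remaining 10%: keep the original token
    return masked, labels


rng = random.Random(0)
ids = [CLS_ID] + list(range(2000, 2020)) + [SEP_ID]  # arbitrary token IDs
masked_ids, labels = mask_tokens(ids, rng)
```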

### Masked Language Modeling

BERT model with only the masked language modeling head, for MLM fine-tuning.

```python { .api }
class BertForMaskedLM:
    def __init__(self, config):
        """
        Initialize BERT for masked language modeling.

        Args:
            config (BertConfig): Model configuration
        """

    def forward(
        self,
        input_ids,
        token_type_ids=None,
        attention_mask=None,
        masked_lm_labels=None
    ):
        """
        Forward pass with MLM head.

        Args:
            input_ids (torch.Tensor): Token IDs
            token_type_ids (torch.Tensor, optional): Segment IDs
            attention_mask (torch.Tensor, optional): Attention mask
            masked_lm_labels (torch.Tensor, optional): MLM labels of shape [batch_size, seq_len];
                positions set to -1 are ignored by the loss

        Returns:
            The MLM loss if masked_lm_labels is provided; otherwise prediction scores
            over the vocabulary, of shape [batch_size, seq_len, vocab_size]
        """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, **kwargs):
        """Load pre-trained model."""
```
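
When `masked_lm_labels` is supplied, the loss is a cross-entropy averaged only over positions whose label is not -1. The plain-Python sketch below mirrors that computation for a single sequence so the role of -1 is explicit. It stands in for `torch.nn.CrossEntropyLoss(ignore_index=-1)` and is not library code; `masked_lm_loss` is hypothetical.

```python
import math


def masked_lm_loss(prediction_scores, labels, ignore_index=-1):
    """Hypothetical helper: mean cross-entropy over positions whose label != ignore_index.

    prediction_scores: [seq_len][vocab_size] raw logits; labels: [seq_len] token IDs or -1.
    """
    total, count = 0.0, 0
    for logits, label in zip(prediction_scores, labels):
        if label == ignore_index:
            continue  # unselected position: contributes nothing
        # numerically stable log-sum-exp around the largest logit
        m = max(logits)
        log_norm = m + math.log(sum(math.exp(x - m) for x in logits))
        total += log_norm - logits[label]  # -log softmax(logits)[label]
        count += 1
    # no labeled positions: a mean over an empty set is undefined
    return total / count if count else float("nan")


# Two positions over a 2-token vocabulary; only the second one is labeled
loss = masked_lm_loss([[5.0, 0.0], [0.0, 0.0]], [-1, 1])
print(loss)  # log(2), about 0.693
```

The first position's strongly peaked logits have no effect on the result, because its label is -1.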

### Next Sentence Prediction

BERT model with only the next sentence prediction head, for NSP tasks.

```python { .api }
class BertForNextSentencePrediction:
    def __init__(self, config):
        """
        Initialize BERT for next sentence prediction.

        Args:
            config (BertConfig): Model configuration
        """

    def forward(
        self,
        input_ids,
        token_type_ids=None,
        attention_mask=None,
        next_sentence_label=None
    ):
        """
        Forward pass with NSP head.

        Args:
            input_ids (torch.Tensor): Token IDs
            token_type_ids (torch.Tensor, optional): Segment IDs
            attention_mask (torch.Tensor, optional): Attention mask
            next_sentence_label (torch.Tensor, optional): NSP labels of shape [batch_size]:
                0 if sentence B continues sentence A, 1 if sentence B is random

        Returns:
            The NSP loss if next_sentence_label is provided; otherwise
            next sentence prediction scores of shape [batch_size, 2]
        """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, **kwargs):
        """Load pre-trained model."""
```
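
Without labels, `BertForNextSentencePrediction` returns two raw scores per example. In pytorch_pretrained_bert, index 0 means "sentence B continues sentence A" and index 1 means "sentence B is random". A softmax over the pair therefore gives the probability of continuation. The plain-Python sketch below shows that conversion for one row of scores; `is_next_probability` is a hypothetical helper.

```python
import math


def is_next_probability(seq_relationship_scores):
    """Hypothetical helper: P(B follows A) from one [2] row of NSP scores.

    Index 0 = B continues A, index 1 = B is random.
    """
    s0, s1 = seq_relationship_scores
    # a two-class softmax reduces to a logistic function of the score difference
    return 1.0 / (1.0 + math.exp(s1 - s0))


print(is_next_probability([0.0, 0.0]))  # 0.5
print(is_next_probability([3.0, -1.0]))
```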

### Sequence Classification

BERT model with a classification head on the pooled output, for sequence-level tasks such as sentiment analysis, text classification, and natural language inference.

```python { .api }
class BertForSequenceClassification:
    def __init__(self, config, num_labels):
        """
        Initialize BERT for sequence classification.

        Args:
            config (BertConfig): Model configuration
            num_labels (int): Number of classification labels
        """

    def forward(
        self,
        input_ids,
        token_type_ids=None,
        attention_mask=None,
        labels=None
    ):
        """
        Forward pass with classification head.

        Args:
            input_ids (torch.Tensor): Token IDs
            token_type_ids (torch.Tensor, optional): Segment IDs
            attention_mask (torch.Tensor, optional): Attention mask
            labels (torch.Tensor, optional): Class labels of shape [batch_size],
                with values in [0, num_labels)

        Returns:
            The cross-entropy loss if labels are provided; otherwise
            classification logits of shape [batch_size, num_labels]
        """

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        cache_dir=None,
        num_labels=2,
        **kwargs
    ):
        """Load pre-trained BERT weights; the classification head is newly initialized."""
```
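
Sentence-pair tasks such as natural language inference must fit `[CLS] A [SEP] B [SEP]` into at most `max_position_embeddings` tokens (512 by default), or a smaller chosen maximum. Google's original BERT fine-tuning scripts handle this by repeatedly dropping the last token of whichever sequence is currently longer, while reserving room for the three special tokens. The sketch below reproduces that heuristic in plain Python; `truncate_pair` is a hypothetical name.

```python
def truncate_pair(tokens_a, tokens_b, max_seq_length):
    """Hypothetical helper: trim two token lists in place so [CLS] A [SEP] B [SEP] fits.

    Removes the last token of the currently longer list until the pair fits,
    reserving 3 positions for [CLS] and the two [SEP] tokens.
    """
    budget = max_seq_length - 3
    if budget < 0:
        raise ValueError("max_seq_length must be at least 3")
    while len(tokens_a) + len(tokens_b) > budget:
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()  # on ties, trim sequence B


a = list(range(10))
b = list(range(100, 104))
truncate_pair(a, b, max_seq_length=12)
print(len(a), len(b))  # 5 4
```

Trimming the longer side first keeps as much of the shorter sequence as possible. This usually matters when one side (for example, a premise) is much longer than the other.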

### Multiple Choice

BERT model for multiple choice tasks, where each example consists of several candidate choices and the model scores each one.

```python { .api }
class BertForMultipleChoice:
    def __init__(self, config, num_choices):
        """
        Initialize BERT for multiple choice.

        Args:
            config (BertConfig): Model configuration
            num_choices (int): Number of choices per example
        """

    def forward(
        self,
        input_ids,
        token_type_ids=None,
        attention_mask=None,
        labels=None
    ):
        """
        Forward pass with multiple choice head.

        Args:
            input_ids (torch.Tensor): Token IDs of shape [batch_size, num_choices, seq_len]
            token_type_ids (torch.Tensor, optional): Segment IDs, same shape as input_ids
            attention_mask (torch.Tensor, optional): Attention mask, same shape as input_ids
            labels (torch.Tensor, optional): Index of the correct choice, of shape [batch_size]

        Returns:
            The cross-entropy loss if labels are provided; otherwise
            choice scores of shape [batch_size, num_choices]
        """

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        cache_dir=None,
        num_choices=2,
        **kwargs
    ):
        """Load pre-trained BERT weights; the classification head is newly initialized."""
```
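
`BertForMultipleChoice` expects each example as `num_choices` candidate sequences stacked along a second dimension, and returns one score per candidate. Typically each candidate is the shared context paired with one answer. The plain-Python sketch below does two things:

- It builds the nested `[batch_size][num_choices][seq_len]` `input_ids` and `token_type_ids` lists.
- It picks the highest-scoring choice.

`build_choices` and `predict_choice` are hypothetical helpers. The IDs 101, 102, and 0 for `[CLS]`, `[SEP]`, and `[PAD]` assume the bert-base-uncased vocabulary. An `attention_mask` (1 for real tokens, 0 for padding) would be built alongside in the same way.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # assumed bert-base-uncased IDs


def build_choices(context_ids, choice_ids_list, max_len):
    """Hypothetical helper: one example -> ([num_choices][max_len] IDs, matching segment IDs)."""
    all_ids, all_types = [], []
    for choice_ids in choice_ids_list:
        # [CLS] context [SEP] choice [SEP]: context is segment 0, choice is segment 1
        ids = [CLS_ID] + context_ids + [SEP_ID] + choice_ids + [SEP_ID]
        types = [0] * (len(context_ids) + 2) + [1] * (len(choice_ids) + 1)
        pad = max_len - len(ids)
        if pad < 0:
            raise ValueError("choice sequence longer than max_len")
        all_ids.append(ids + [PAD_ID] * pad)
        all_types.append(types + [0] * pad)
    return all_ids, all_types


def predict_choice(choice_scores):
    """Index of the highest score in one [num_choices] row of model output."""
    return max(range(len(choice_scores)), key=lambda i: choice_scores[i])


example_ids, example_types = build_choices([2054, 2003], [[1037], [1038, 1039]], max_len=8)
batch_input_ids = [example_ids]  # [batch_size=1][num_choices=2][seq_len=8]
print(predict_choice([0.3, 1.2]))  # 1
```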

### Token Classification

BERT model with a token-level classification head for tasks such as named entity recognition and part-of-speech tagging.

```python { .api }
class BertForTokenClassification:
    def __init__(self, config, num_labels):
        """
        Initialize BERT for token classification.

        Args:
            config (BertConfig): Model configuration
            num_labels (int): Number of token classification labels
        """

    def forward(
        self,
        input_ids,
        token_type_ids=None,
        attention_mask=None,
        labels=None
    ):
        """
        Forward pass with token classification head.

        Args:
            input_ids (torch.Tensor): Token IDs
            token_type_ids (torch.Tensor, optional): Segment IDs
            attention_mask (torch.Tensor, optional): Attention mask
            labels (torch.Tensor, optional): Token labels of shape [batch_size, seq_len],
                with values in [0, num_labels)

        Returns:
            The cross-entropy loss if labels are provided; otherwise
            token classification logits of shape [batch_size, seq_len, num_labels]
        """

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        cache_dir=None,
        num_labels=2,
        **kwargs
    ):
        """Load pre-trained BERT weights; the classification head is newly initialized."""
```
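
Token classification labels apply to WordPiece tokens, but NER datasets usually label whole words. WordPiece splits some words into several pieces, and continuation pieces start with `##`. A common convention, which follows the BERT paper's NER setup of classifying from each word's first sub-token, is to:

- label only the first piece of each word, and
- exclude the remaining pieces, along with `[CLS]` and `[SEP]`, from the loss.

The plain-Python sketch below performs that alignment. It assumes the words are already split into WordPiece strings. It marks excluded positions with -100, which is the default `ignore_index` of PyTorch's `CrossEntropyLoss`; check which value your loss actually ignores before relying on it. `align_labels` is hypothetical.

```python
IGNORE = -100  # assumption: torch.nn.CrossEntropyLoss's default ignore_index


def align_labels(word_pieces, word_labels):
    """Hypothetical helper: expand word-level label IDs to WordPiece positions.

    word_pieces: one list of WordPiece strings per word.
    word_labels: one integer label ID per word.
    Returns (tokens, labels) including [CLS]/[SEP], with IGNORE on positions
    that should not contribute to the loss.
    """
    tokens, labels = ["[CLS]"], [IGNORE]
    for pieces, label in zip(word_pieces, word_labels):
        for j, piece in enumerate(pieces):
            tokens.append(piece)
            # only the first piece of each word carries the word's label
            labels.append(label if j == 0 else IGNORE)
    tokens.append("[SEP]")
    labels.append(IGNORE)
    return tokens, labels


# illustrative split; a real tokenizer may split these words differently
tokens, labels = align_labels([["john"], ["lives"], ["in"], ["wash", "##ington"]], [1, 0, 0, 2])
print(tokens)  # ['[CLS]', 'john', 'lives', 'in', 'wash', '##ington', '[SEP]']
print(labels)  # [-100, 1, 0, 0, 2, -100, -100]
```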

### Question Answering

BERT model with a span-prediction head for extractive question answering tasks such as SQuAD.

```python { .api }
class BertForQuestionAnswering:
    def __init__(self, config):
        """
        Initialize BERT for question answering.

        Args:
            config (BertConfig): Model configuration
        """

    def forward(
        self,
        input_ids,
        token_type_ids=None,
        attention_mask=None,
        start_positions=None,
        end_positions=None
    ):
        """
        Forward pass with QA head.

        Args:
            input_ids (torch.Tensor): Token IDs
            token_type_ids (torch.Tensor, optional): Segment IDs
            attention_mask (torch.Tensor, optional): Attention mask
            start_positions (torch.Tensor, optional): Answer start token indices of shape [batch_size]
            end_positions (torch.Tensor, optional): Answer end token indices of shape [batch_size]

        Returns:
            The average of the start and end cross-entropy losses if both start_positions
            and end_positions are provided; otherwise a tuple (start_logits, end_logits),
            each of shape [batch_size, seq_len]
        """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, **kwargs):
        """Load pre-trained BERT weights; the span-prediction head is newly initialized."""
```
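
Without positions, `BertForQuestionAnswering` returns one start score and one end score per token. One way to turn these into an answer, a simplified version of what the original BERT SQuAD scripts do, is to pick the (start, end) pair that maximizes `start_score + end_score`, subject to two constraints: `start <= end`, and the answer length stays within a maximum. The sketch below searches all such pairs in plain Python for one example. `best_span` and its optional `allowed` mask are hypothetical. The mask could, for example, restrict answers to passage tokens where `token_type_ids` is 1.

```python
def best_span(start_scores, end_scores, max_answer_length=30, allowed=None):
    """Hypothetical helper: return (start, end, score) maximizing start + end score.

    allowed: optional list of booleans; positions marked False (e.g. question
    tokens or padding) can be neither start nor end. Returns None if no span is valid.
    """
    n = len(start_scores)
    if allowed is None:
        allowed = [True] * n
    best = None
    for i in range(n):
        if not allowed[i]:
            continue
        # end may not precede start, and the span is at most max_answer_length tokens
        for j in range(i, min(n, i + max_answer_length)):
            if not allowed[j]:
                continue
            score = start_scores[i] + end_scores[j]
            if best is None or score > best[2]:
                best = (i, j, score)
    return best


start = [0.0, 1.0, 5.0, 0.0]
end = [0.0, 0.0, 1.0, 6.0]
print(best_span(start, end))  # (2, 3, 11.0)
```

Searching all pairs costs O(seq_len × max_answer_length), which is negligible at BERT's sequence lengths.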

## Weight Loading

A function that loads the weights of a TensorFlow BERT checkpoint into a PyTorch model.

```python { .api }
def load_tf_weights_in_bert(model, tf_checkpoint_path):
    """
    Load a TensorFlow BERT checkpoint into a PyTorch model.
    Requires TensorFlow to be installed.

    Args:
        model: PyTorch BERT model instance whose architecture matches the checkpoint
        tf_checkpoint_path (str): Path to the TensorFlow checkpoint

    Returns:
        The PyTorch model with the loaded weights
    """
```
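
Conceptually, loading a TensorFlow checkpoint is a renaming exercise: each TensorFlow variable path is mapped onto the matching PyTorch parameter. The main rules are:

- `kernel` and `gamma` become `weight`, and `beta` becomes `bias`.
- `layer_N` becomes an index into the layer list.
- Optimizer slots such as `adam_m` and `adam_v` are skipped.
- Dense kernels are also transposed.

The plain-Python sketch below approximates only the name mapping. It is a simplified illustration, not the library's implementation, and `tf_name_to_torch` is hypothetical.

```python
import re

SKIP = {"adam_m", "adam_v", "global_step"}  # optimizer state, not model weights
RENAME = {"kernel": "weight", "gamma": "weight", "output_weights": "weight",
          "beta": "bias", "output_bias": "bias"}


def tf_name_to_torch(tf_name):
    """Hypothetical helper: approximate PyTorch parameter name for a TF variable.

    Returns None for variables that would be skipped.
    """
    parts = tf_name.split("/")
    if any(p in SKIP for p in parts):
        return None
    out = []
    for part in parts:
        m = re.fullmatch(r"([A-Za-z]+)_(\d+)", part)
        if m:
            # e.g. "layer_3" -> "layer.3" (an index into a module list)
            out.extend([m.group(1), m.group(2)])
        else:
            out.append(RENAME.get(part, part))
    if parts[-1].endswith("_embeddings"):
        out.append("weight")  # embedding tables map to the module's weight
    return ".".join(out)


print(tf_name_to_torch("bert/encoder/layer_3/attention/self/query/kernel"))
# bert.encoder.layer.3.attention.self.query.weight
```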

## Usage Examples

### Basic BERT Model

```python
from pytorch_pretrained_bert import BertModel, BertConfig

# Create a randomly initialized model from a configuration
# (the first argument is the vocabulary size, or a path to a config JSON file)
config = BertConfig(vocab_size_or_config_json_file=30522, hidden_size=768)
model = BertModel(config)

# Or load a pre-trained model (weights are downloaded and cached)
model = BertModel.from_pretrained('bert-base-uncased')
```

### Sequence Classification

```python
from pytorch_pretrained_bert import BertForSequenceClassification
import torch

# Load for 3-class classification; the classification head is newly
# initialized, so its outputs are only meaningful after fine-tuning
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=3)
model.eval()  # disable dropout for inference

# Forward pass
input_ids = torch.tensor([[101, 2023, 2003, 102]])  # [CLS] this is [SEP]
with torch.no_grad():
    # Without labels, the model returns the logits tensor itself, shape [1, 3]
    logits = model(input_ids)
```

### Question Answering

```python
from pytorch_pretrained_bert import BertForQuestionAnswering
import torch

# Load QA model; the span-prediction head must be fine-tuned (e.g. on SQuAD)
model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')
model.eval()  # disable dropout for inference

# Forward pass with question and passage: [CLS] question [SEP] passage [SEP]
input_ids = torch.tensor([[101, 2054, 2003, 102, 2023, 2003, 1996, 3437, 102]])
token_type_ids = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1, 1]])  # 0=question, 1=passage

with torch.no_grad():
    # Without positions, the model returns (start_logits, end_logits), each [1, 9]
    start_scores, end_scores = model(input_ids, token_type_ids=token_type_ids)
```