0
# Text Models
1
2
This document describes the Keras Hub implementations of transformer models for natural language processing tasks. Keras Hub provides both backbone models (core architectures) and task-specific models with specialized heads for classification, masked language modeling, causal language modeling, and sequence-to-sequence tasks.
3
4
## Capabilities
5
6
### Base Classes
7
8
Foundation classes that define the interface for different types of text models.
9
10
```python { .api }
11
class Task:
12
"""Base class for all tasks."""
13
@classmethod
14
def from_preset(cls, preset: str, **kwargs): ...
15
def compile(self, **kwargs): ...
16
def fit(self, x, y=None, **kwargs): ...
17
def predict(self, x, **kwargs): ...
18
def generate(self, inputs, **kwargs): ...
19
20
class Backbone:
21
"""Base class for model backbones."""
22
@classmethod
23
def from_preset(cls, preset: str, **kwargs): ...
24
25
class CausalLM(Task):
26
"""Base class for causal language models."""
27
def generate(self, inputs, max_length: int = None, **kwargs): ...
28
29
class MaskedLM(Task):
30
"""Base class for masked language models."""
31
...
32
33
class Seq2SeqLM(Task):
34
"""Base class for sequence-to-sequence models."""
35
def generate(self, inputs, max_length: int = None, **kwargs): ...
36
37
class TextClassifier(Task):
38
"""Base class for text classification models."""
39
...
40
41
# Alias
42
Classifier = TextClassifier
43
```
44
45
### BERT (Bidirectional Encoder Representations from Transformers)
46
47
BERT models for bidirectional language understanding, suitable for classification and masked language modeling tasks.
48
49
```python { .api }
50
class BertBackbone(Backbone):
51
"""BERT transformer backbone."""
52
def __init__(
53
self,
54
vocabulary_size: int,
55
num_layers: int,
56
num_heads: int,
57
hidden_dim: int,
58
intermediate_dim: int,
59
dropout: float = 0.1,
60
max_sequence_length: int = 512,
61
**kwargs
62
): ...
63
64
class BertTextClassifier(TextClassifier):
65
"""BERT model for text classification."""
66
def __init__(
67
self,
68
backbone: BertBackbone,
69
num_classes: int,
70
preprocessor: Preprocessor = None,
71
**kwargs
72
): ...
73
74
class BertMaskedLM(MaskedLM):
75
"""BERT model for masked language modeling."""
76
def __init__(
77
self,
78
backbone: BertBackbone,
79
preprocessor: Preprocessor = None,
80
**kwargs
81
): ...
82
83
class BertMaskedLMPreprocessor:
84
"""Preprocessor for BERT masked language modeling."""
85
def __init__(
86
self,
87
tokenizer: BertTokenizer,
88
sequence_length: int = 512,
89
mask_selection_rate: float = 0.15,
90
mask_token_rate: float = 0.8,
91
random_token_rate: float = 0.1,
92
**kwargs
93
): ...
94
95
class BertTextClassifierPreprocessor:
96
"""Preprocessor for BERT text classification."""
97
def __init__(
98
self,
99
tokenizer: BertTokenizer,
100
sequence_length: int = 512,
101
**kwargs
102
): ...
103
104
class BertTokenizer:
105
"""BERT tokenizer using WordPiece algorithm."""
106
def __init__(
107
self,
108
vocabulary: dict = None,
109
lowercase: bool = True,
110
**kwargs
111
): ...
112
113
# Aliases
114
BertClassifier = BertTextClassifier
115
BertPreprocessor = BertTextClassifierPreprocessor
116
```
117
118
### GPT-2 (Generative Pre-trained Transformer 2)
119
120
GPT-2 models for causal language modeling and text generation.
121
122
```python { .api }
123
class GPT2Backbone(Backbone):
124
"""GPT-2 transformer backbone."""
125
def __init__(
126
self,
127
vocabulary_size: int,
128
num_layers: int,
129
num_heads: int,
130
hidden_dim: int,
131
intermediate_dim: int,
132
dropout: float = 0.1,
133
max_sequence_length: int = 1024,
134
**kwargs
135
): ...
136
137
class GPT2CausalLM(CausalLM):
138
"""GPT-2 model for causal language modeling."""
139
def __init__(
140
self,
141
backbone: GPT2Backbone,
142
preprocessor: Preprocessor = None,
143
**kwargs
144
): ...
145
146
class GPT2CausalLMPreprocessor:
147
"""Preprocessor for GPT-2 causal language modeling."""
148
def __init__(
149
self,
150
tokenizer: GPT2Tokenizer,
151
sequence_length: int = 1024,
152
add_start_token: bool = False,
153
add_end_token: bool = False,
154
**kwargs
155
): ...
156
157
class GPT2Preprocessor:
158
"""General preprocessor for GPT-2."""
159
def __init__(
160
self,
161
tokenizer: GPT2Tokenizer,
162
sequence_length: int = 1024,
163
**kwargs
164
): ...
165
166
class GPT2Tokenizer:
167
"""GPT-2 tokenizer using byte-pair encoding."""
168
def __init__(
169
self,
170
vocabulary: dict = None,
171
merges: list = None,
172
**kwargs
173
): ...
174
```
175
176
### RoBERTa (Robustly Optimized BERT Pretraining Approach)
177
178
RoBERTa models optimized for robust performance on downstream tasks.
179
180
```python { .api }
181
class RobertaBackbone(Backbone):
182
"""RoBERTa transformer backbone."""
183
def __init__(
184
self,
185
vocabulary_size: int,
186
num_layers: int,
187
num_heads: int,
188
hidden_dim: int,
189
intermediate_dim: int,
190
dropout: float = 0.1,
191
max_sequence_length: int = 512,
192
**kwargs
193
): ...
194
195
class RobertaTextClassifier(TextClassifier):
196
"""RoBERTa model for text classification."""
197
def __init__(
198
self,
199
backbone: RobertaBackbone,
200
num_classes: int,
201
preprocessor: Preprocessor = None,
202
**kwargs
203
): ...
204
205
class RobertaMaskedLM(MaskedLM):
206
"""RoBERTa model for masked language modeling."""
207
def __init__(
208
self,
209
backbone: RobertaBackbone,
210
preprocessor: Preprocessor = None,
211
**kwargs
212
): ...
213
214
class RobertaMaskedLMPreprocessor:
215
"""Preprocessor for RoBERTa masked language modeling."""
216
def __init__(
217
self,
218
tokenizer: RobertaTokenizer,
219
sequence_length: int = 512,
220
mask_selection_rate: float = 0.15,
221
mask_token_rate: float = 0.8,
222
random_token_rate: float = 0.1,
223
**kwargs
224
): ...
225
226
class RobertaTextClassifierPreprocessor:
227
"""Preprocessor for RoBERTa text classification."""
228
def __init__(
229
self,
230
tokenizer: RobertaTokenizer,
231
sequence_length: int = 512,
232
**kwargs
233
): ...
234
235
class RobertaTokenizer:
236
"""RoBERTa tokenizer using byte-pair encoding."""
237
def __init__(
238
self,
239
vocabulary: dict = None,
240
merges: list = None,
241
**kwargs
242
): ...
243
244
# Aliases
245
RobertaClassifier = RobertaTextClassifier
246
RobertaPreprocessor = RobertaTextClassifierPreprocessor
247
```
248
249
### BART (Bidirectional and Auto-Regressive Transformers)
250
251
BART models for sequence-to-sequence tasks like summarization and translation.
252
253
```python { .api }
254
class BartBackbone(Backbone):
255
"""BART transformer backbone."""
256
def __init__(
257
self,
258
vocabulary_size: int,
259
num_layers: int,
260
num_heads: int,
261
hidden_dim: int,
262
intermediate_dim: int,
263
dropout: float = 0.1,
264
max_sequence_length: int = 1024,
265
**kwargs
266
): ...
267
268
class BartSeq2SeqLM(Seq2SeqLM):
269
"""BART model for sequence-to-sequence tasks."""
270
def __init__(
271
self,
272
backbone: BartBackbone,
273
preprocessor: Preprocessor = None,
274
**kwargs
275
): ...
276
277
class BartSeq2SeqLMPreprocessor:
278
"""Preprocessor for BART sequence-to-sequence modeling."""
279
def __init__(
280
self,
281
tokenizer: BartTokenizer,
282
encoder_sequence_length: int = 1024,
283
decoder_sequence_length: int = 1024,
284
**kwargs
285
): ...
286
287
class BartTokenizer:
288
"""BART tokenizer using byte-pair encoding."""
289
def __init__(
290
self,
291
vocabulary: dict = None,
292
merges: list = None,
293
**kwargs
294
): ...
295
```
296
297
### DistilBERT (Distilled BERT)
298
299
DistilBERT is a smaller, faster distilled version of BERT that retains most of BERT's accuracy on downstream tasks.
300
301
```python { .api }
302
class DistilBertBackbone(Backbone):
303
"""DistilBERT transformer backbone."""
304
def __init__(
305
self,
306
vocabulary_size: int,
307
num_layers: int,
308
num_heads: int,
309
hidden_dim: int,
310
intermediate_dim: int,
311
dropout: float = 0.1,
312
max_sequence_length: int = 512,
313
**kwargs
314
): ...
315
316
class DistilBertTextClassifier(TextClassifier):
317
"""DistilBERT model for text classification."""
318
def __init__(
319
self,
320
backbone: DistilBertBackbone,
321
num_classes: int,
322
preprocessor: Preprocessor = None,
323
**kwargs
324
): ...
325
326
class DistilBertMaskedLM(MaskedLM):
327
"""DistilBERT model for masked language modeling."""
328
def __init__(
329
self,
330
backbone: DistilBertBackbone,
331
preprocessor: Preprocessor = None,
332
**kwargs
333
): ...
334
335
class DistilBertMaskedLMPreprocessor:
336
"""Preprocessor for DistilBERT masked language modeling."""
337
def __init__(
338
self,
339
tokenizer: DistilBertTokenizer,
340
sequence_length: int = 512,
341
mask_selection_rate: float = 0.15,
342
mask_token_rate: float = 0.8,
343
random_token_rate: float = 0.1,
344
**kwargs
345
): ...
346
347
class DistilBertTextClassifierPreprocessor:
348
"""Preprocessor for DistilBERT text classification."""
349
def __init__(
350
self,
351
tokenizer: DistilBertTokenizer,
352
sequence_length: int = 512,
353
**kwargs
354
): ...
355
356
class DistilBertTokenizer:
357
"""DistilBERT tokenizer using WordPiece algorithm."""
358
def __init__(
359
self,
360
vocabulary: dict = None,
361
lowercase: bool = True,
362
**kwargs
363
): ...
364
365
# Aliases
366
DistilBertClassifier = DistilBertTextClassifier
367
DistilBertPreprocessor = DistilBertTextClassifierPreprocessor
368
```
369
370
### Large Language Models
371
372
Modern large language models for advanced text generation and understanding.
373
374
```python { .api }
375
# Llama
376
class LlamaBackbone(Backbone): ...
377
class LlamaCausalLM(CausalLM): ...
378
class LlamaCausalLMPreprocessor: ...
379
class LlamaTokenizer: ...
380
381
# Llama 3
382
class Llama3Backbone(Backbone): ...
383
class Llama3CausalLM(CausalLM): ...
384
class Llama3CausalLMPreprocessor: ...
385
class Llama3Tokenizer: ...
386
387
# Mistral
388
class MistralBackbone(Backbone): ...
389
class MistralCausalLM(CausalLM): ...
390
class MistralCausalLMPreprocessor: ...
391
class MistralTokenizer: ...
392
393
# Mixtral (Mixture of Experts)
394
class MixtralBackbone(Backbone): ...
395
class MixtralCausalLM(CausalLM): ...
396
class MixtralCausalLMPreprocessor: ...
397
class MixtralTokenizer: ...
398
399
# Gemma
400
class GemmaBackbone(Backbone): ...
401
class GemmaCausalLM(CausalLM): ...
402
class GemmaCausalLMPreprocessor: ...
403
class GemmaTokenizer: ...
404
405
# Gemma 3
406
class Gemma3Backbone(Backbone): ...
407
class Gemma3CausalLM(CausalLM): ...
408
class Gemma3CausalLMPreprocessor: ...
409
class Gemma3Tokenizer: ...
410
411
# BLOOM
412
class BloomBackbone(Backbone): ...
413
class BloomCausalLM(CausalLM): ...
414
class BloomCausalLMPreprocessor: ...
415
class BloomTokenizer: ...
416
417
# OPT
418
class OPTBackbone(Backbone): ...
419
class OPTCausalLM(CausalLM): ...
420
class OPTCausalLMPreprocessor: ...
421
class OPTTokenizer: ...
422
423
# GPT-NeoX
424
class GPTNeoXBackbone(Backbone): ...
425
class GPTNeoXCausalLM(CausalLM): ...
426
class GPTNeoXCausalLMPreprocessor: ...
427
class GPTNeoXTokenizer: ...
428
429
# Falcon
430
class FalconBackbone(Backbone): ...
431
class FalconCausalLM(CausalLM): ...
432
class FalconCausalLMPreprocessor: ...
433
class FalconTokenizer: ...
434
435
# Phi-3
436
class Phi3Backbone(Backbone): ...
437
class Phi3CausalLM(CausalLM): ...
438
class Phi3CausalLMPreprocessor: ...
439
class Phi3Tokenizer: ...
440
441
# Qwen / Qwen 2
442
class QwenBackbone(Backbone): ...
443
class QwenCausalLM(CausalLM): ...
444
class QwenCausalLMPreprocessor: ...
445
class QwenTokenizer: ...
446
447
# Aliases for Qwen 2
448
Qwen2Backbone = QwenBackbone
449
Qwen2CausalLM = QwenCausalLM
450
Qwen2CausalLMPreprocessor = QwenCausalLMPreprocessor
451
Qwen2Tokenizer = QwenTokenizer
452
453
# Qwen 3
454
class Qwen3Backbone(Backbone): ...
455
class Qwen3CausalLM(CausalLM): ...
456
class Qwen3CausalLMPreprocessor: ...
457
class Qwen3Tokenizer: ...
458
459
# Qwen MoE
460
class QwenMoeBackbone(Backbone): ...
461
class QwenMoeCausalLM(CausalLM): ...
462
class QwenMoeCausalLMPreprocessor: ...
463
class QwenMoeTokenizer: ...
464
```
465
466
### Specialized Text Models
467
468
Additional text models for specific domains and tasks.
469
470
```python { .api }
471
# ALBERT (A Lite BERT)
472
class AlbertBackbone(Backbone): ...
473
class AlbertTextClassifier(TextClassifier): ...
474
class AlbertMaskedLM(MaskedLM): ...
475
class AlbertMaskedLMPreprocessor: ...
476
class AlbertTextClassifierPreprocessor: ...
477
class AlbertTokenizer: ...
478
479
# Aliases
480
AlbertClassifier = AlbertTextClassifier
481
AlbertPreprocessor = AlbertTextClassifierPreprocessor
482
483
# DeBERTa V3 (Decoding-enhanced BERT with Disentangled Attention)
484
class DebertaV3Backbone(Backbone): ...
485
class DebertaV3TextClassifier(TextClassifier): ...
486
class DebertaV3MaskedLM(MaskedLM): ...
487
class DebertaV3MaskedLMPreprocessor: ...
488
class DebertaV3TextClassifierPreprocessor: ...
489
class DebertaV3Tokenizer: ...
490
491
# Aliases
492
DebertaV3Classifier = DebertaV3TextClassifier
493
DebertaV3Preprocessor = DebertaV3TextClassifierPreprocessor
494
495
# ELECTRA (Efficiently Learning an Encoder that Classifies Token Replacements Accurately)
496
class ElectraBackbone(Backbone): ...
497
class ElectraTokenizer: ...
498
499
# F-Net (Fourier Transform-based Transformer)
500
class FNetBackbone(Backbone): ...
501
class FNetTextClassifier(TextClassifier): ...
502
class FNetMaskedLM(MaskedLM): ...
503
class FNetMaskedLMPreprocessor: ...
504
class FNetTextClassifierPreprocessor: ...
505
class FNetTokenizer: ...
506
507
# Aliases
508
FNetClassifier = FNetTextClassifier
509
FNetPreprocessor = FNetTextClassifierPreprocessor
510
511
# XLM-RoBERTa (Cross-lingual Language Model - RoBERTa)
512
class XLMRobertaBackbone(Backbone): ...
513
class XLMRobertaTextClassifier(TextClassifier): ...
514
class XLMRobertaMaskedLM(MaskedLM): ...
515
class XLMRobertaMaskedLMPreprocessor: ...
516
class XLMRobertaTextClassifierPreprocessor: ...
517
class XLMRobertaTokenizer: ...
518
519
# Aliases
520
XLMRobertaClassifier = XLMRobertaTextClassifier
521
XLMRobertaPreprocessor = XLMRobertaTextClassifierPreprocessor
522
523
# XLNet
524
class XLNetBackbone(Backbone): ...
525
526
# RoFormer V2 (Rotary Position Embedding Transformer V2)
527
class RoformerV2Backbone(Backbone): ...
528
class RoformerV2TextClassifier(TextClassifier): ...
529
class RoformerV2MaskedLM(MaskedLM): ...
530
class RoformerV2MaskedLMPreprocessor: ...
531
class RoformerV2TextClassifierPreprocessor: ...
532
class RoformerV2Tokenizer: ...
533
534
# T5 (Text-To-Text Transfer Transformer)
535
class T5Backbone(Backbone): ...
536
class T5Preprocessor: ...
537
class T5Tokenizer: ...
538
539
# ESM (Evolutionary Scale Modeling) - Protein Language Models
540
class ESMBackbone(Backbone): ...
541
class ESMProteinClassifier: ...
542
class ESMProteinClassifierPreprocessor: ...
543
class ESMMaskedPLM: ...
544
class ESMMaskedPLMPreprocessor: ...
545
class ESMTokenizer: ...
546
547
# Aliases
548
ESM2Backbone = ESMBackbone
549
ESM2MaskedPLM = ESMMaskedPLM
550
```
551
552
### Preprocessor Base Classes
553
554
Base classes for text preprocessing.
555
556
```python { .api }
557
class Preprocessor:
558
"""Base class for preprocessors."""
559
@classmethod
560
def from_preset(cls, preset: str, **kwargs): ...
561
def __call__(self, x, y=None, sample_weight=None): ...
562
563
class CausalLMPreprocessor(Preprocessor):
564
"""Base preprocessor for causal language models."""
565
def __init__(
566
self,
567
tokenizer: Tokenizer,
568
sequence_length: int = 1024,
569
add_start_token: bool = False,
570
add_end_token: bool = False,
571
**kwargs
572
): ...
573
574
class MaskedLMPreprocessor(Preprocessor):
575
"""Base preprocessor for masked language models."""
576
def __init__(
577
self,
578
tokenizer: Tokenizer,
579
sequence_length: int = 512,
580
mask_selection_rate: float = 0.15,
581
mask_token_rate: float = 0.8,
582
random_token_rate: float = 0.1,
583
**kwargs
584
): ...
585
586
class Seq2SeqLMPreprocessor(Preprocessor):
587
"""Base preprocessor for sequence-to-sequence models."""
588
def __init__(
589
self,
590
tokenizer: Tokenizer,
591
encoder_sequence_length: int = 1024,
592
decoder_sequence_length: int = 1024,
593
**kwargs
594
): ...
595
596
class TextClassifierPreprocessor(Preprocessor):
597
"""Base preprocessor for text classification."""
598
def __init__(
599
self,
600
tokenizer: Tokenizer,
601
sequence_length: int = 512,
602
**kwargs
603
): ...
604
```
605
606
## Usage Examples
607
608
### Text Classification with BERT
609
610
```python
611
import keras_hub
612
613
# Load pretrained BERT classifier
614
classifier = keras_hub.models.BertTextClassifier.from_preset(
615
"bert_base_en",
616
num_classes=2 # Binary classification
617
)
618
619
# Compile model
620
classifier.compile(
621
optimizer="adam",
622
loss="sparse_categorical_crossentropy",
623
metrics=["accuracy"]
624
)
625
626
# Prepare data
627
train_texts = ["This movie is great!", "I didn't like this film."]
628
train_labels = [1, 0]
629
630
# Train
631
classifier.fit(train_texts, train_labels, epochs=3)
632
633
# Predict
634
predictions = classifier.predict(["A wonderful story!"])
635
print(predictions)
636
```
637
638
### Text Generation with GPT-2
639
640
```python
641
import keras_hub
642
643
# Load pretrained GPT-2 model
644
generator = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")
645
646
# Generate text
647
prompt = "The future of artificial intelligence is"
648
generated = generator.generate(prompt, max_length=100)
649
print(generated)
650
651
# Control generation with sampling
652
sampler = keras_hub.samplers.TopKSampler(k=50, temperature=0.8)
653
generator.compile(sampler=sampler)  # the sampler is configured via compile(), not generate()
generated = generator.generate(prompt, max_length=100)
654
print(generated)
655
```
656
657
### Masked Language Modeling with RoBERTa
658
659
```python
660
import keras_hub
661
662
# Load RoBERTa masked LM
663
model = keras_hub.models.RobertaMaskedLM.from_preset("roberta_base_en")
664
665
# Predict masked tokens
666
# RoBERTa's mask token is "<mask>" (BERT-style models use "[MASK]")
text_with_mask = "The capital of France is <mask>."
667
predictions = model.predict([text_with_mask])
668
print(predictions)
669
```