or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

audio-models.md, evaluation-metrics.md, generative-models.md, image-models.md, index.md, layers-components.md, multimodal-models.md, text-generation-sampling.md, text-models.md, tokenizers.md, utilities-helpers.md

docs/text-models.md

0

# Text Models

1

2

Keras Hub offers comprehensive implementations of transformer models for natural language processing tasks. It provides both backbone models (the core architectures) and task-specific models that add specialized heads for classification, masked language modeling, causal language modeling, and sequence-to-sequence tasks.

3

4

## Capabilities

5

6

### Base Classes

7

8

Foundation classes that define the interface for different types of text models.

9

10

```python { .api }

11

class Task:

12

"""Base class for all tasks."""

13

@classmethod

14

def from_preset(cls, preset: str, **kwargs): ...

15

def compile(self, **kwargs): ...

16

def fit(self, x, y=None, **kwargs): ...

17

def predict(self, x, **kwargs): ...

18

def generate(self, inputs, **kwargs): ...

19

20

class Backbone:

21

"""Base class for model backbones."""

22

@classmethod

23

def from_preset(cls, preset: str, **kwargs): ...

24

25

class CausalLM(Task):

26

"""Base class for causal language models."""

27

def generate(self, inputs, max_length: int = None, **kwargs): ...

28

29

class MaskedLM(Task):

30

"""Base class for masked language models."""

31

...

32

33

class Seq2SeqLM(Task):

34

"""Base class for sequence-to-sequence models."""

35

def generate(self, inputs, max_length: int = None, **kwargs): ...

36

37

class TextClassifier(Task):

38

"""Base class for text classification models."""

39

...

40

41

# Alias

42

Classifier = TextClassifier

43

```

44

45

### BERT (Bidirectional Encoder Representations from Transformers)

46

47

BERT models for bidirectional language understanding, suitable for classification and masked language modeling tasks.

48

49

```python { .api }

50

class BertBackbone(Backbone):

51

"""BERT transformer backbone."""

52

def __init__(

53

self,

54

vocabulary_size: int,

55

num_layers: int,

56

num_heads: int,

57

hidden_dim: int,

58

intermediate_dim: int,

59

dropout: float = 0.1,

60

max_sequence_length: int = 512,

61

**kwargs

62

): ...

63

64

class BertTextClassifier(TextClassifier):

65

"""BERT model for text classification."""

66

def __init__(

67

self,

68

backbone: BertBackbone,

69

num_classes: int,

70

preprocessor: Preprocessor = None,

71

**kwargs

72

): ...

73

74

class BertMaskedLM(MaskedLM):

75

"""BERT model for masked language modeling."""

76

def __init__(

77

self,

78

backbone: BertBackbone,

79

preprocessor: Preprocessor = None,

80

**kwargs

81

): ...

82

83

class BertMaskedLMPreprocessor:

84

"""Preprocessor for BERT masked language modeling."""

85

def __init__(

86

self,

87

tokenizer: BertTokenizer,

88

sequence_length: int = 512,

89

mask_selection_rate: float = 0.15,

90

mask_token_rate: float = 0.8,

91

random_token_rate: float = 0.1,

92

**kwargs

93

): ...

94

95

class BertTextClassifierPreprocessor:

96

"""Preprocessor for BERT text classification."""

97

def __init__(

98

self,

99

tokenizer: BertTokenizer,

100

sequence_length: int = 512,

101

**kwargs

102

): ...

103

104

class BertTokenizer:

105

"""BERT tokenizer using WordPiece algorithm."""

106

def __init__(

107

self,

108

vocabulary: dict = None,

109

lowercase: bool = True,

110

**kwargs

111

): ...

112

113

# Aliases

114

BertClassifier = BertTextClassifier

115

BertPreprocessor = BertTextClassifierPreprocessor

116

```

117

118

### GPT-2 (Generative Pre-trained Transformer 2)

119

120

GPT-2 models for causal language modeling and text generation.

121

122

```python { .api }

123

class GPT2Backbone(Backbone):

124

"""GPT-2 transformer backbone."""

125

def __init__(

126

self,

127

vocabulary_size: int,

128

num_layers: int,

129

num_heads: int,

130

hidden_dim: int,

131

intermediate_dim: int,

132

dropout: float = 0.1,

133

max_sequence_length: int = 1024,

134

**kwargs

135

): ...

136

137

class GPT2CausalLM(CausalLM):

138

"""GPT-2 model for causal language modeling."""

139

def __init__(

140

self,

141

backbone: GPT2Backbone,

142

preprocessor: Preprocessor = None,

143

**kwargs

144

): ...

145

146

class GPT2CausalLMPreprocessor:

147

"""Preprocessor for GPT-2 causal language modeling."""

148

def __init__(

149

self,

150

tokenizer: GPT2Tokenizer,

151

sequence_length: int = 1024,

152

add_start_token: bool = False,

153

add_end_token: bool = False,

154

**kwargs

155

): ...

156

157

class GPT2Preprocessor:

158

"""General preprocessor for GPT-2."""

159

def __init__(

160

self,

161

tokenizer: GPT2Tokenizer,

162

sequence_length: int = 1024,

163

**kwargs

164

): ...

165

166

class GPT2Tokenizer:

167

"""GPT-2 tokenizer using byte-pair encoding."""

168

def __init__(

169

self,

170

vocabulary: dict = None,

171

merges: list = None,

172

**kwargs

173

): ...

174

```

175

176

### RoBERTa (Robustly Optimized BERT Pretraining Approach)

177

178

RoBERTa models: BERT-style encoders trained with a more robustly optimized pretraining procedure, suitable for text classification and masked language modeling tasks.

179

180

```python { .api }

181

class RobertaBackbone(Backbone):

182

"""RoBERTa transformer backbone."""

183

def __init__(

184

self,

185

vocabulary_size: int,

186

num_layers: int,

187

num_heads: int,

188

hidden_dim: int,

189

intermediate_dim: int,

190

dropout: float = 0.1,

191

max_sequence_length: int = 512,

192

**kwargs

193

): ...

194

195

class RobertaTextClassifier(TextClassifier):

196

"""RoBERTa model for text classification."""

197

def __init__(

198

self,

199

backbone: RobertaBackbone,

200

num_classes: int,

201

preprocessor: Preprocessor = None,

202

**kwargs

203

): ...

204

205

class RobertaMaskedLM(MaskedLM):

206

"""RoBERTa model for masked language modeling."""

207

def __init__(

208

self,

209

backbone: RobertaBackbone,

210

preprocessor: Preprocessor = None,

211

**kwargs

212

): ...

213

214

class RobertaMaskedLMPreprocessor:

215

"""Preprocessor for RoBERTa masked language modeling."""

216

def __init__(

217

self,

218

tokenizer: RobertaTokenizer,

219

sequence_length: int = 512,

220

mask_selection_rate: float = 0.15,

221

mask_token_rate: float = 0.8,

222

random_token_rate: float = 0.1,

223

**kwargs

224

): ...

225

226

class RobertaTextClassifierPreprocessor:

227

"""Preprocessor for RoBERTa text classification."""

228

def __init__(

229

self,

230

tokenizer: RobertaTokenizer,

231

sequence_length: int = 512,

232

**kwargs

233

): ...

234

235

class RobertaTokenizer:

236

"""RoBERTa tokenizer using byte-pair encoding."""

237

def __init__(

238

self,

239

vocabulary: dict = None,

240

merges: list = None,

241

**kwargs

242

): ...

243

244

# Aliases

245

RobertaClassifier = RobertaTextClassifier

246

RobertaPreprocessor = RobertaTextClassifierPreprocessor

247

```

248

249

### BART (Bidirectional and Auto-Regressive Transformers)

250

251

BART models for sequence-to-sequence tasks like summarization and translation.

252

253

```python { .api }

254

class BartBackbone(Backbone):

255

"""BART transformer backbone."""

256

def __init__(

257

self,

258

vocabulary_size: int,

259

num_layers: int,

260

num_heads: int,

261

hidden_dim: int,

262

intermediate_dim: int,

263

dropout: float = 0.1,

264

max_sequence_length: int = 1024,

265

**kwargs

266

): ...

267

268

class BartSeq2SeqLM(Seq2SeqLM):

269

"""BART model for sequence-to-sequence tasks."""

270

def __init__(

271

self,

272

backbone: BartBackbone,

273

preprocessor: Preprocessor = None,

274

**kwargs

275

): ...

276

277

class BartSeq2SeqLMPreprocessor:

278

"""Preprocessor for BART sequence-to-sequence modeling."""

279

def __init__(

280

self,

281

tokenizer: BartTokenizer,

282

encoder_sequence_length: int = 1024,

283

decoder_sequence_length: int = 1024,

284

**kwargs

285

): ...

286

287

class BartTokenizer:

288

"""BART tokenizer using byte-pair encoding."""

289

def __init__(

290

self,

291

vocabulary: dict = None,

292

merges: list = None,

293

**kwargs

294

): ...

295

```

296

297

### DistilBERT (Distilled BERT)

298

299

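DistilBERT models: a smaller, faster distilled version of BERT with comparable performance, suitable for text classification and masked language modeling tasks.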

300

301

```python { .api }

302

class DistilBertBackbone(Backbone):

303

"""DistilBERT transformer backbone."""

304

def __init__(

305

self,

306

vocabulary_size: int,

307

num_layers: int,

308

num_heads: int,

309

hidden_dim: int,

310

intermediate_dim: int,

311

dropout: float = 0.1,

312

max_sequence_length: int = 512,

313

**kwargs

314

): ...

315

316

class DistilBertTextClassifier(TextClassifier):

317

"""DistilBERT model for text classification."""

318

def __init__(

319

self,

320

backbone: DistilBertBackbone,

321

num_classes: int,

322

preprocessor: Preprocessor = None,

323

**kwargs

324

): ...

325

326

class DistilBertMaskedLM(MaskedLM):

327

"""DistilBERT model for masked language modeling."""

328

def __init__(

329

self,

330

backbone: DistilBertBackbone,

331

preprocessor: Preprocessor = None,

332

**kwargs

333

): ...

334

335

class DistilBertMaskedLMPreprocessor:

336

"""Preprocessor for DistilBERT masked language modeling."""

337

def __init__(

338

self,

339

tokenizer: DistilBertTokenizer,

340

sequence_length: int = 512,

341

mask_selection_rate: float = 0.15,

342

mask_token_rate: float = 0.8,

343

random_token_rate: float = 0.1,

344

**kwargs

345

): ...

346

347

class DistilBertTextClassifierPreprocessor:

348

"""Preprocessor for DistilBERT text classification."""

349

def __init__(

350

self,

351

tokenizer: DistilBertTokenizer,

352

sequence_length: int = 512,

353

**kwargs

354

): ...

355

356

class DistilBertTokenizer:

357

"""DistilBERT tokenizer using WordPiece algorithm."""

358

def __init__(

359

self,

360

vocabulary: dict = None,

361

lowercase: bool = True,

362

**kwargs

363

): ...

364

365

# Aliases

366

DistilBertClassifier = DistilBertTextClassifier

367

DistilBertPreprocessor = DistilBertTextClassifierPreprocessor

368

```

369

370

### Large Language Models

371

372

Modern large language models for advanced text generation and understanding.

373

374

```python { .api }

375

# Llama

376

class LlamaBackbone(Backbone): ...

377

class LlamaCausalLM(CausalLM): ...

378

class LlamaCausalLMPreprocessor: ...

379

class LlamaTokenizer: ...

380

381

# Llama 3

382

class Llama3Backbone(Backbone): ...

383

class Llama3CausalLM(CausalLM): ...

384

class Llama3CausalLMPreprocessor: ...

385

class Llama3Tokenizer: ...

386

387

# Mistral

388

class MistralBackbone(Backbone): ...

389

class MistralCausalLM(CausalLM): ...

390

class MistralCausalLMPreprocessor: ...

391

class MistralTokenizer: ...

392

393

# Mixtral (Mixture of Experts)

394

class MixtralBackbone(Backbone): ...

395

class MixtralCausalLM(CausalLM): ...

396

class MixtralCausalLMPreprocessor: ...

397

class MixtralTokenizer: ...

398

399

# Gemma

400

class GemmaBackbone(Backbone): ...

401

class GemmaCausalLM(CausalLM): ...

402

class GemmaCausalLMPreprocessor: ...

403

class GemmaTokenizer: ...

404

405

# Gemma 3

406

class Gemma3Backbone(Backbone): ...

407

class Gemma3CausalLM(CausalLM): ...

408

class Gemma3CausalLMPreprocessor: ...

409

class Gemma3Tokenizer: ...

410

411

# BLOOM

412

class BloomBackbone(Backbone): ...

413

class BloomCausalLM(CausalLM): ...

414

class BloomCausalLMPreprocessor: ...

415

class BloomTokenizer: ...

416

417

# OPT

418

class OPTBackbone(Backbone): ...

419

class OPTCausalLM(CausalLM): ...

420

class OPTCausalLMPreprocessor: ...

421

class OPTTokenizer: ...

422

423

# GPT-NeoX

424

class GPTNeoXBackbone(Backbone): ...

425

class GPTNeoXCausalLM(CausalLM): ...

426

class GPTNeoXCausalLMPreprocessor: ...

427

class GPTNeoXTokenizer: ...

428

429

# Falcon

430

class FalconBackbone(Backbone): ...

431

class FalconCausalLM(CausalLM): ...

432

class FalconCausalLMPreprocessor: ...

433

class FalconTokenizer: ...

434

435

# Phi-3

436

class Phi3Backbone(Backbone): ...

437

class Phi3CausalLM(CausalLM): ...

438

class Phi3CausalLMPreprocessor: ...

439

class Phi3Tokenizer: ...

440

441

# Qwen / Qwen 2

442

class QwenBackbone(Backbone): ...

443

class QwenCausalLM(CausalLM): ...

444

class QwenCausalLMPreprocessor: ...

445

class QwenTokenizer: ...

446

447

# Aliases for Qwen 2

448

Qwen2Backbone = QwenBackbone

449

Qwen2CausalLM = QwenCausalLM

450

Qwen2CausalLMPreprocessor = QwenCausalLMPreprocessor

451

Qwen2Tokenizer = QwenTokenizer

452

453

# Qwen 3

454

class Qwen3Backbone(Backbone): ...

455

class Qwen3CausalLM(CausalLM): ...

456

class Qwen3CausalLMPreprocessor: ...

457

class Qwen3Tokenizer: ...

458

459

# Qwen MoE

460

class QwenMoeBackbone(Backbone): ...

461

class QwenMoeCausalLM(CausalLM): ...

462

class QwenMoeCausalLMPreprocessor: ...

463

class QwenMoeTokenizer: ...

464

```

465

466

### Specialized Text Models

467

468

Additional text models for specific domains and tasks.

469

470

```python { .api }

471

# ALBERT (A Lite BERT)

472

class AlbertBackbone(Backbone): ...

473

class AlbertTextClassifier(TextClassifier): ...

474

class AlbertMaskedLM(MaskedLM): ...

475

class AlbertMaskedLMPreprocessor: ...

476

class AlbertTextClassifierPreprocessor: ...

477

class AlbertTokenizer: ...

478

479

# Aliases

480

AlbertClassifier = AlbertTextClassifier

481

AlbertPreprocessor = AlbertTextClassifierPreprocessor

482

483

# DeBERTa V3 (Decoding-enhanced BERT with Disentangled Attention)

484

class DebertaV3Backbone(Backbone): ...

485

class DebertaV3TextClassifier(TextClassifier): ...

486

class DebertaV3MaskedLM(MaskedLM): ...

487

class DebertaV3MaskedLMPreprocessor: ...

488

class DebertaV3TextClassifierPreprocessor: ...

489

class DebertaV3Tokenizer: ...

490

491

# Aliases

492

DebertaV3Classifier = DebertaV3TextClassifier

493

DebertaV3Preprocessor = DebertaV3TextClassifierPreprocessor

494

495

# ELECTRA (Efficiently Learning an Encoder that Classifies Token Replacements Accurately)

496

class ElectraBackbone(Backbone): ...

497

class ElectraTokenizer: ...

498

499

# F-Net (Fourier Transform-based Transformer)

500

class FNetBackbone(Backbone): ...

501

class FNetTextClassifier(TextClassifier): ...

502

class FNetMaskedLM(MaskedLM): ...

503

class FNetMaskedLMPreprocessor: ...

504

class FNetTextClassifierPreprocessor: ...

505

class FNetTokenizer: ...

506

507

# Aliases

508

FNetClassifier = FNetTextClassifier

509

FNetPreprocessor = FNetTextClassifierPreprocessor

510

511

# XLM-RoBERTa (Cross-lingual Language Model - RoBERTa)

512

class XLMRobertaBackbone(Backbone): ...

513

class XLMRobertaTextClassifier(TextClassifier): ...

514

class XLMRobertaMaskedLM(MaskedLM): ...

515

class XLMRobertaMaskedLMPreprocessor: ...

516

class XLMRobertaTextClassifierPreprocessor: ...

517

class XLMRobertaTokenizer: ...

518

519

# Aliases

520

XLMRobertaClassifier = XLMRobertaTextClassifier

521

XLMRobertaPreprocessor = XLMRobertaTextClassifierPreprocessor

522

523

# XLNet

524

class XLNetBackbone(Backbone): ...

525

526

# RoFormer V2 (Rotary Position Embedding Transformer V2)

527

class RoformerV2Backbone(Backbone): ...

528

class RoformerV2TextClassifier(TextClassifier): ...

529

class RoformerV2MaskedLM(MaskedLM): ...

530

class RoformerV2MaskedLMPreprocessor: ...

531

class RoformerV2TextClassifierPreprocessor: ...

532

class RoformerV2Tokenizer: ...

533

534

# T5 (Text-To-Text Transfer Transformer)

535

class T5Backbone(Backbone): ...

536

class T5Preprocessor: ...

537

class T5Tokenizer: ...

538

539

# ESM (Evolutionary Scale Modeling) - Protein Language Models

540

class ESMBackbone(Backbone): ...

541

class ESMProteinClassifier: ...

542

class ESMProteinClassifierPreprocessor: ...

543

class ESMMaskedPLM: ...

544

class ESMMaskedPLMPreprocessor: ...

545

class ESMTokenizer: ...

546

547

# Aliases

548

ESM2Backbone = ESMBackbone

549

ESM2MaskedPLM = ESMMaskedPLM

550

```

551

552

### Preprocessor Base Classes

553

554

Base classes for text preprocessing.

555

556

```python { .api }

557

class Preprocessor:

558

"""Base class for preprocessors."""

559

@classmethod

560

def from_preset(cls, preset: str, **kwargs): ...

561

def __call__(self, x, y=None, sample_weight=None): ...

562

563

class CausalLMPreprocessor(Preprocessor):

564

"""Base preprocessor for causal language models."""

565

def __init__(

566

self,

567

tokenizer: Tokenizer,

568

sequence_length: int = 1024,

569

add_start_token: bool = False,

570

add_end_token: bool = False,

571

**kwargs

572

): ...

573

574

class MaskedLMPreprocessor(Preprocessor):

575

"""Base preprocessor for masked language models."""

576

def __init__(

577

self,

578

tokenizer: Tokenizer,

579

sequence_length: int = 512,

580

mask_selection_rate: float = 0.15,

581

mask_token_rate: float = 0.8,

582

random_token_rate: float = 0.1,

583

**kwargs

584

): ...

585

586

class Seq2SeqLMPreprocessor(Preprocessor):

587

"""Base preprocessor for sequence-to-sequence models."""

588

def __init__(

589

self,

590

tokenizer: Tokenizer,

591

encoder_sequence_length: int = 1024,

592

decoder_sequence_length: int = 1024,

593

**kwargs

594

): ...

595

596

class TextClassifierPreprocessor(Preprocessor):

597

"""Base preprocessor for text classification."""

598

def __init__(

599

self,

600

tokenizer: Tokenizer,

601

sequence_length: int = 512,

602

**kwargs

603

): ...

604

```

605

606

## Usage Examples

607

608

### Text Classification with BERT

609

610

```python

611

import keras_hub

612

613

# Load pretrained BERT classifier

614

classifier = keras_hub.models.BertTextClassifier.from_preset(

615

"bert_base_en",

616

num_classes=2 # Binary classification

617

)

618

619

# Compile model

620

classifier.compile(

621

optimizer="adam",

622

loss="sparse_categorical_crossentropy",

623

metrics=["accuracy"]

624

)

625

626

# Prepare data

627

train_texts = ["This movie is great!", "I didn't like this film."]

628

train_labels = [1, 0]

629

630

# Train

631

classifier.fit(train_texts, train_labels, epochs=3)

632

633

# Predict

634

predictions = classifier.predict(["A wonderful story!"])

635

print(predictions)

636

```

637

638

### Text Generation with GPT-2

639

640

```python

641

import keras_hub

642

643

# Load pretrained GPT-2 model

644

generator = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

645

646

# Generate text

647

prompt = "The future of artificial intelligence is"

648

generated = generator.generate(prompt, max_length=100)

649

print(generated)

650

651

# Control generation with sampling

652

sampler = keras_hub.samplers.TopKSampler(k=50, temperature=0.8)

653

generator.compile(sampler=sampler)
generated = generator.generate(prompt, max_length=100)

654

print(generated)

655

```

656

657

### Masked Language Modeling with RoBERTa

658

659

```python

660

import keras_hub

661

662

# Load RoBERTa masked LM

663

model = keras_hub.models.RobertaMaskedLM.from_preset("roberta_base_en")

664

665

# Predict masked tokens

666

text_with_mask = "The capital of France is <mask>."

667

predictions = model.predict([text_with_mask])

668

print(predictions)

669

```