Pretrained models for Keras with multi-framework compatibility.

Comprehensive implementations of transformer models for natural language processing tasks. Keras Hub provides both backbone models (core architectures) and task-specific models with specialized heads for classification, masked language modeling, causal language modeling, and sequence-to-sequence tasks.

Foundation classes that define the interface for the different types of text models.
class Task:
    """Common interface shared by every task model (train/predict/generate)."""

    @classmethod
    def from_preset(cls, preset: str, **kwargs): ...

    def compile(self, **kwargs): ...

    def fit(self, x, y=None, **kwargs): ...

    def predict(self, x, **kwargs): ...

    def generate(self, inputs, **kwargs): ...


class Backbone:
    """Common interface shared by every core model architecture."""

    @classmethod
    def from_preset(cls, preset: str, **kwargs): ...


class CausalLM(Task):
    """Interface for left-to-right (causal) language models."""

    def generate(self, inputs, max_length: int = None, **kwargs): ...


class MaskedLM(Task):
    """Interface for masked language models."""
    ...


class Seq2SeqLM(Task):
    """Interface for encoder-decoder (sequence-to-sequence) language models."""

    def generate(self, inputs, max_length: int = None, **kwargs): ...


class TextClassifier(Task):
    """Interface for text-classification models."""
    ...


# Backwards-compatible alias.
Classifier = TextClassifier

# --- BERT models for bidirectional language understanding, suitable for
# --- classification and masked language modeling tasks. ---
class BertBackbone(Backbone):
    """Core BERT encoder stack (embeddings plus transformer layers)."""

    def __init__(
        self,
        vocabulary_size: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        intermediate_dim: int,
        dropout: float = 0.1,
        max_sequence_length: int = 512,
        **kwargs
    ): ...


class BertTextClassifier(TextClassifier):
    """BERT backbone topped with a classification head."""

    def __init__(
        self,
        backbone: BertBackbone,
        num_classes: int,
        preprocessor: Preprocessor = None,
        **kwargs
    ): ...


class BertMaskedLM(MaskedLM):
    """BERT backbone topped with a masked-token prediction head."""

    def __init__(
        self,
        backbone: BertBackbone,
        preprocessor: Preprocessor = None,
        **kwargs
    ): ...


class BertMaskedLMPreprocessor:
    """Tokenizes inputs and applies random masking for BERT MLM training."""

    def __init__(
        self,
        # NOTE(review): forward reference — BertTokenizer is defined below.
        tokenizer: BertTokenizer,
        sequence_length: int = 512,
        mask_selection_rate: float = 0.15,
        mask_token_rate: float = 0.8,
        random_token_rate: float = 0.1,
        **kwargs
    ): ...


class BertTextClassifierPreprocessor:
    """Tokenizes and packs inputs for BERT text classification."""

    def __init__(
        self,
        tokenizer: BertTokenizer,
        sequence_length: int = 512,
        **kwargs
    ): ...


class BertTokenizer:
    """WordPiece tokenizer matching BERT's vocabulary."""

    def __init__(
        self,
        vocabulary: dict = None,
        lowercase: bool = True,
        **kwargs
    ): ...


# Backwards-compatible aliases.
BertClassifier = BertTextClassifier
BertPreprocessor = BertTextClassifierPreprocessor

# --- GPT-2 models for causal language modeling and text generation. ---
class GPT2Backbone(Backbone):
    """Core GPT-2 decoder stack (embeddings plus transformer layers)."""

    def __init__(
        self,
        vocabulary_size: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        intermediate_dim: int,
        dropout: float = 0.1,
        max_sequence_length: int = 1024,
        **kwargs
    ): ...


class GPT2CausalLM(CausalLM):
    """GPT-2 backbone topped with a next-token (causal LM) head."""

    def __init__(
        self,
        backbone: GPT2Backbone,
        preprocessor: Preprocessor = None,
        **kwargs
    ): ...


class GPT2CausalLMPreprocessor:
    """Tokenizes and packs inputs for GPT-2 causal language modeling."""

    def __init__(
        self,
        tokenizer: GPT2Tokenizer,
        sequence_length: int = 1024,
        # NOTE(review): the released keras-hub preprocessor defaults both
        # token flags to True — confirm these False defaults against upstream.
        add_start_token: bool = False,
        add_end_token: bool = False,
        **kwargs
    ): ...


class GPT2Preprocessor:
    """General-purpose tokenize-and-pack preprocessor for GPT-2."""

    def __init__(
        self,
        tokenizer: GPT2Tokenizer,
        sequence_length: int = 1024,
        **kwargs
    ): ...


class GPT2Tokenizer:
    """Byte-pair-encoding tokenizer matching GPT-2's vocabulary."""

    def __init__(
        self,
        vocabulary: dict = None,
        merges: list = None,
        **kwargs
    ): ...


# --- RoBERTa models optimized for robust performance on downstream tasks. ---
class RobertaBackbone(Backbone):
    """Core RoBERTa encoder stack (embeddings plus transformer layers)."""

    def __init__(
        self,
        vocabulary_size: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        intermediate_dim: int,
        dropout: float = 0.1,
        max_sequence_length: int = 512,
        **kwargs
    ): ...


class RobertaTextClassifier(TextClassifier):
    """RoBERTa backbone topped with a classification head."""

    def __init__(
        self,
        backbone: RobertaBackbone,
        num_classes: int,
        preprocessor: Preprocessor = None,
        **kwargs
    ): ...


class RobertaMaskedLM(MaskedLM):
    """RoBERTa backbone topped with a masked-token prediction head."""

    def __init__(
        self,
        backbone: RobertaBackbone,
        preprocessor: Preprocessor = None,
        **kwargs
    ): ...


class RobertaMaskedLMPreprocessor:
    """Tokenizes inputs and applies random masking for RoBERTa MLM training."""

    def __init__(
        self,
        tokenizer: RobertaTokenizer,
        sequence_length: int = 512,
        mask_selection_rate: float = 0.15,
        mask_token_rate: float = 0.8,
        random_token_rate: float = 0.1,
        **kwargs
    ): ...


class RobertaTextClassifierPreprocessor:
    """Tokenizes and packs inputs for RoBERTa text classification."""

    def __init__(
        self,
        tokenizer: RobertaTokenizer,
        sequence_length: int = 512,
        **kwargs
    ): ...


class RobertaTokenizer:
    """Byte-pair-encoding tokenizer matching RoBERTa's vocabulary."""

    def __init__(
        self,
        vocabulary: dict = None,
        merges: list = None,
        **kwargs
    ): ...


# Backwards-compatible aliases.
RobertaClassifier = RobertaTextClassifier
RobertaPreprocessor = RobertaTextClassifierPreprocessor

# --- BART models for sequence-to-sequence tasks like summarization and
# --- translation. ---
class BartBackbone(Backbone):
    """Core BART encoder-decoder transformer stack."""

    def __init__(
        self,
        vocabulary_size: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        intermediate_dim: int,
        dropout: float = 0.1,
        max_sequence_length: int = 1024,
        **kwargs
    ): ...


class BartSeq2SeqLM(Seq2SeqLM):
    """BART backbone topped with a sequence-to-sequence generation head."""

    def __init__(
        self,
        backbone: BartBackbone,
        preprocessor: Preprocessor = None,
        **kwargs
    ): ...


class BartSeq2SeqLMPreprocessor:
    """Tokenizes and packs encoder/decoder inputs for BART seq2seq modeling."""

    def __init__(
        self,
        tokenizer: BartTokenizer,
        encoder_sequence_length: int = 1024,
        decoder_sequence_length: int = 1024,
        **kwargs
    ): ...


class BartTokenizer:
    """Byte-pair-encoding tokenizer matching BART's vocabulary."""

    def __init__(
        self,
        vocabulary: dict = None,
        merges: list = None,
        **kwargs
    ): ...


# --- Smaller, faster version of BERT with comparable performance. ---
class DistilBertBackbone(Backbone):
    """Core DistilBERT encoder stack (distilled, smaller BERT)."""

    def __init__(
        self,
        vocabulary_size: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        intermediate_dim: int,
        dropout: float = 0.1,
        max_sequence_length: int = 512,
        **kwargs
    ): ...


class DistilBertTextClassifier(TextClassifier):
    """DistilBERT backbone topped with a classification head."""

    def __init__(
        self,
        backbone: DistilBertBackbone,
        num_classes: int,
        preprocessor: Preprocessor = None,
        **kwargs
    ): ...


class DistilBertMaskedLM(MaskedLM):
    """DistilBERT backbone topped with a masked-token prediction head."""

    def __init__(
        self,
        backbone: DistilBertBackbone,
        preprocessor: Preprocessor = None,
        **kwargs
    ): ...


class DistilBertMaskedLMPreprocessor:
    """Tokenizes inputs and applies random masking for DistilBERT MLM training."""

    def __init__(
        self,
        tokenizer: DistilBertTokenizer,
        sequence_length: int = 512,
        mask_selection_rate: float = 0.15,
        mask_token_rate: float = 0.8,
        random_token_rate: float = 0.1,
        **kwargs
    ): ...


class DistilBertTextClassifierPreprocessor:
    """Tokenizes and packs inputs for DistilBERT text classification."""

    def __init__(
        self,
        tokenizer: DistilBertTokenizer,
        sequence_length: int = 512,
        **kwargs
    ): ...


class DistilBertTokenizer:
    """WordPiece tokenizer matching DistilBERT's vocabulary."""

    def __init__(
        self,
        vocabulary: dict = None,
        lowercase: bool = True,
        **kwargs
    ): ...


# Backwards-compatible aliases.
DistilBertClassifier = DistilBertTextClassifier
DistilBertPreprocessor = DistilBertTextClassifierPreprocessor

# --- Modern large language models for advanced text generation and
# --- understanding. ---
# Each model family below follows the same four-class pattern:
# Backbone (architecture), CausalLM (generation task), CausalLMPreprocessor,
# and Tokenizer.

# Llama
class LlamaBackbone(Backbone): ...
class LlamaCausalLM(CausalLM): ...
class LlamaCausalLMPreprocessor: ...
class LlamaTokenizer: ...

# Llama 3
class Llama3Backbone(Backbone): ...
class Llama3CausalLM(CausalLM): ...
class Llama3CausalLMPreprocessor: ...
class Llama3Tokenizer: ...

# Mistral
class MistralBackbone(Backbone): ...
class MistralCausalLM(CausalLM): ...
class MistralCausalLMPreprocessor: ...
class MistralTokenizer: ...

# Mixtral (Mixture of Experts)
class MixtralBackbone(Backbone): ...
class MixtralCausalLM(CausalLM): ...
class MixtralCausalLMPreprocessor: ...
class MixtralTokenizer: ...

# Gemma
class GemmaBackbone(Backbone): ...
class GemmaCausalLM(CausalLM): ...
class GemmaCausalLMPreprocessor: ...
class GemmaTokenizer: ...

# Gemma 3
class Gemma3Backbone(Backbone): ...
class Gemma3CausalLM(CausalLM): ...
class Gemma3CausalLMPreprocessor: ...
class Gemma3Tokenizer: ...

# BLOOM
class BloomBackbone(Backbone): ...
class BloomCausalLM(CausalLM): ...
class BloomCausalLMPreprocessor: ...
class BloomTokenizer: ...

# OPT
class OPTBackbone(Backbone): ...
class OPTCausalLM(CausalLM): ...
class OPTCausalLMPreprocessor: ...
class OPTTokenizer: ...

# GPT-NeoX
class GPTNeoXBackbone(Backbone): ...
class GPTNeoXCausalLM(CausalLM): ...
class GPTNeoXCausalLMPreprocessor: ...
class GPTNeoXTokenizer: ...

# Falcon
class FalconBackbone(Backbone): ...
class FalconCausalLM(CausalLM): ...
class FalconCausalLMPreprocessor: ...
class FalconTokenizer: ...

# Phi-3
class Phi3Backbone(Backbone): ...
class Phi3CausalLM(CausalLM): ...
class Phi3CausalLMPreprocessor: ...
class Phi3Tokenizer: ...

# Qwen / Qwen 2
class QwenBackbone(Backbone): ...
class QwenCausalLM(CausalLM): ...
class QwenCausalLMPreprocessor: ...
class QwenTokenizer: ...

# Backwards-compatible aliases for the Qwen 2 naming.
Qwen2Backbone = QwenBackbone
Qwen2CausalLM = QwenCausalLM
Qwen2CausalLMPreprocessor = QwenCausalLMPreprocessor
Qwen2Tokenizer = QwenTokenizer

# Qwen 3
class Qwen3Backbone(Backbone): ...
class Qwen3CausalLM(CausalLM): ...
class Qwen3CausalLMPreprocessor: ...
class Qwen3Tokenizer: ...

# Qwen MoE
class QwenMoeBackbone(Backbone): ...
class QwenMoeCausalLM(CausalLM): ...
class QwenMoeCausalLMPreprocessor: ...
class QwenMoeTokenizer: ...

# --- Additional text models for specific domains and tasks. ---
# ALBERT (A Lite BERT)
class AlbertBackbone(Backbone): ...
class AlbertTextClassifier(TextClassifier): ...
class AlbertMaskedLM(MaskedLM): ...
class AlbertMaskedLMPreprocessor: ...
class AlbertTextClassifierPreprocessor: ...
class AlbertTokenizer: ...

# Backwards-compatible aliases.
AlbertClassifier = AlbertTextClassifier
AlbertPreprocessor = AlbertTextClassifierPreprocessor

# DeBERTa V3 (Decoding-enhanced BERT with Disentangled Attention)
class DebertaV3Backbone(Backbone): ...
class DebertaV3TextClassifier(TextClassifier): ...
class DebertaV3MaskedLM(MaskedLM): ...
class DebertaV3MaskedLMPreprocessor: ...
class DebertaV3TextClassifierPreprocessor: ...
class DebertaV3Tokenizer: ...

# Backwards-compatible aliases.
DebertaV3Classifier = DebertaV3TextClassifier
DebertaV3Preprocessor = DebertaV3TextClassifierPreprocessor

# ELECTRA (Efficiently Learning an Encoder that Classifies Token
# Replacements Accurately)
class ElectraBackbone(Backbone): ...
class ElectraTokenizer: ...

# F-Net (Fourier Transform-based Transformer)
class FNetBackbone(Backbone): ...
class FNetTextClassifier(TextClassifier): ...
class FNetMaskedLM(MaskedLM): ...
class FNetMaskedLMPreprocessor: ...
class FNetTextClassifierPreprocessor: ...
class FNetTokenizer: ...

# Backwards-compatible aliases.
FNetClassifier = FNetTextClassifier
FNetPreprocessor = FNetTextClassifierPreprocessor

# XLM-RoBERTa (Cross-lingual Language Model - RoBERTa)
class XLMRobertaBackbone(Backbone): ...
class XLMRobertaTextClassifier(TextClassifier): ...
class XLMRobertaMaskedLM(MaskedLM): ...
class XLMRobertaMaskedLMPreprocessor: ...
class XLMRobertaTextClassifierPreprocessor: ...
class XLMRobertaTokenizer: ...

# Backwards-compatible aliases.
XLMRobertaClassifier = XLMRobertaTextClassifier
XLMRobertaPreprocessor = XLMRobertaTextClassifierPreprocessor

# XLNet
class XLNetBackbone(Backbone): ...

# RoFormer V2 (Rotary Position Embedding Transformer V2)
class RoformerV2Backbone(Backbone): ...
class RoformerV2TextClassifier(TextClassifier): ...
class RoformerV2MaskedLM(MaskedLM): ...
class RoformerV2MaskedLMPreprocessor: ...
class RoformerV2TextClassifierPreprocessor: ...
class RoformerV2Tokenizer: ...

# T5 (Text-To-Text Transfer Transformer)
class T5Backbone(Backbone): ...
class T5Preprocessor: ...
class T5Tokenizer: ...

# ESM (Evolutionary Scale Modeling) - Protein Language Models
class ESMBackbone(Backbone): ...
class ESMProteinClassifier: ...
class ESMProteinClassifierPreprocessor: ...
class ESMMaskedPLM: ...
class ESMMaskedPLMPreprocessor: ...
class ESMTokenizer: ...

# Backwards-compatible aliases (ESM-2 naming).
ESM2Backbone = ESMBackbone
ESM2MaskedPLM = ESMMaskedPLM

# --- Base classes for text preprocessing. ---
class Preprocessor:
    """Common interface shared by every preprocessing layer."""

    @classmethod
    def from_preset(cls, preset: str, **kwargs): ...

    def __call__(self, x, y=None, sample_weight=None): ...


class CausalLMPreprocessor(Preprocessor):
    """Shared tokenize-and-pack logic for causal language model tasks."""

    def __init__(
        self,
        # NOTE(review): a `Tokenizer` base type is referenced here but never
        # defined anywhere in this file — confirm where it comes from.
        tokenizer: Tokenizer,
        sequence_length: int = 1024,
        add_start_token: bool = False,
        add_end_token: bool = False,
        **kwargs
    ): ...


class MaskedLMPreprocessor(Preprocessor):
    """Shared tokenize-and-mask logic for masked language model tasks."""

    def __init__(
        self,
        tokenizer: Tokenizer,
        sequence_length: int = 512,
        mask_selection_rate: float = 0.15,
        mask_token_rate: float = 0.8,
        random_token_rate: float = 0.1,
        **kwargs
    ): ...


class Seq2SeqLMPreprocessor(Preprocessor):
    """Shared encoder/decoder packing logic for seq2seq language model tasks."""

    def __init__(
        self,
        tokenizer: Tokenizer,
        encoder_sequence_length: int = 1024,
        decoder_sequence_length: int = 1024,
        **kwargs
    ): ...


class TextClassifierPreprocessor(Preprocessor):
    """Shared tokenize-and-pack logic for text-classification tasks."""

    def __init__(
        self,
        tokenizer: Tokenizer,
        sequence_length: int = 512,
        **kwargs
    ): ...


# --- Usage example: text classification with BERT. ---
import keras_hub

# Load a pretrained BERT classifier with a 2-way (binary) head.
classifier = keras_hub.models.BertTextClassifier.from_preset(
    "bert_base_en",
    num_classes=2,
)

# Configure the optimizer, loss, and metrics for fine-tuning.
classifier.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

# Toy training data: raw strings in, integer class labels out.
train_texts = ["This movie is great!", "I didn't like this film."]
train_labels = [1, 0]

# Fine-tune the classifier.
classifier.fit(train_texts, train_labels, epochs=3)

# Run inference on new text.
predictions = classifier.predict(["A wonderful story!"])
print(predictions)

# --- Usage example: text generation with GPT-2. ---
import keras_hub

# Load a pretrained GPT-2 causal language model.
generator = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Generate a continuation of the prompt with the default decoding strategy.
prompt = "The future of artificial intelligence is"
generated = generator.generate(prompt, max_length=100)
print(generated)

# Customize decoding with a top-k sampler.
sampler = keras_hub.samplers.TopKSampler(k=50, temperature=0.8)
generated = generator.generate(prompt, max_length=100, sampler=sampler)
print(generated)

# --- Usage example: masked-token prediction with RoBERTa. ---
import keras_hub

# Load a pretrained RoBERTa masked language model.
model = keras_hub.models.RobertaMaskedLM.from_preset("roberta_base_en")

# Predict the masked token. RoBERTa's BPE vocabulary uses "<mask>" as its
# mask token — the original example used BERT's "[MASK]", which RoBERTa's
# tokenizer would split into ordinary subword tokens instead of masking.
text_with_mask = "The capital of France is <mask>."
predictions = model.predict([text_with_mask])
print(predictions)

# Install with the Tessl CLI:
#   npx tessl i tessl/pypi-keras-hub