Pretrained models for Keras with multi-framework compatibility.

Sampling strategies for controlling text generation behavior in language models. Keras Hub provides various sampling methods that balance quality, diversity, and controllability in generated text.

Foundation classes for text generation sampling.
class Sampler:
    """Base class for all samplers.

    Subclasses implement a token-selection strategy used during text
    generation: given the model's next-token logits and the tokens seen
    so far, a sampler chooses the next token.
    """

    def __init__(self, **kwargs): ...

    def __call__(
        self,
        next_token_logits,
        prompt_tokens,
        generated_tokens,
        **kwargs,
    ): ...

    # Select a token from a probability distribution.
    def get_next_token(self, probabilities): ...


# Samplers that produce deterministic outputs given the same input.
class GreedySampler(Sampler):
    """Greedy sampling always selects the token with highest probability.

    Produces deterministic but potentially repetitive outputs.
    """

    def __init__(self, **kwargs): ...
class BeamSampler(Sampler):
    """Beam search maintains multiple candidate sequences and selects
    the sequence with highest overall probability.

    Args:
        num_beams: Number of candidate sequences kept at each step.
        return_all_beams: If True, return all beams instead of only the
            single best sequence.
    """

    def __init__(
        self,
        num_beams: int = 5,
        return_all_beams: bool = False,
        **kwargs,
    ): ...


# Samplers that introduce randomness for more diverse outputs.
class RandomSampler(Sampler):
    """Random sampling selects tokens according to their probability
    distribution. Higher temperature increases randomness.

    Args:
        temperature: Softmax temperature. Values above 1.0 flatten the
            distribution (more random); values below sharpen it.
        seed: Optional random seed for reproducible sampling.
    """

    def __init__(
        self,
        temperature: float = 1.0,
        seed: int | None = None,  # default None, so annotate as optional
        **kwargs,
    ): ...
class TopKSampler(Sampler):
    """Top-k sampling considers only the k most likely tokens at each step.

    Balances quality and diversity by filtering low-probability tokens.

    Args:
        k: Number of highest-probability tokens to sample from.
        temperature: Softmax temperature applied before sampling.
        seed: Optional random seed for reproducible sampling.
    """

    def __init__(
        self,
        k: int = 50,
        temperature: float = 1.0,
        seed: int | None = None,  # default None, so annotate as optional
        **kwargs,
    ): ...
class TopPSampler(Sampler):
    """Top-p (nucleus) sampling considers tokens whose cumulative
    probability is within the top p fraction. Adapts the number of
    considered tokens based on the probability distribution.

    Args:
        p: Cumulative probability mass to sample from (0 < p <= 1).
        temperature: Softmax temperature applied before sampling.
        seed: Optional random seed for reproducible sampling.
    """

    def __init__(
        self,
        p: float = 0.9,
        temperature: float = 1.0,
        seed: int | None = None,  # default None, so annotate as optional
        **kwargs,
    ): ...


# More sophisticated sampling strategies for improved generation quality.
class ContrastiveSampler(Sampler):
    """Contrastive search balances high probability and low repetition
    by penalizing tokens that are too similar to previously generated
    tokens.

    Args:
        k: Number of top tokens to consider at each step.
        alpha: Balance between model probability and novelty penalty.
    """

    def __init__(
        self,
        k: int = 4,
        alpha: float = 0.6,
        **kwargs,
    ): ...


# Utilities for working with samplers programmatically.
def serialize(sampler: Sampler) -> dict:
    """Serialize a sampler instance to a dictionary.

    Args:
        sampler: The sampler instance to serialize.

    Returns:
        Dictionary representation of the sampler.
    """
    ...
def deserialize(config: dict) -> Sampler:
    """Deserialize a sampler from a dictionary configuration.

    Args:
        config: Dictionary configuration of the sampler, as produced by
            `serialize`.

    Returns:
        Sampler instance.
    """
    ...
def get(identifier) -> Sampler:
    """Get a sampler by name or return an existing sampler instance.

    Args:
        identifier: String name (e.g. "greedy") or a sampler instance,
            which is returned unchanged.

    Returns:
        Sampler instance.
    """
    ...


# Example: deterministic generation with greedy sampling.
import keras_hub
# Load model
model = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Create greedy sampler
sampler = keras_hub.samplers.GreedySampler()

# Generate text deterministically
prompt = "The future of artificial intelligence"
output = model.generate(prompt, max_length=50, sampler=sampler)
print("Greedy output:", output)

# Example: controlling randomness with temperature.
import keras_hub
# Load model
model = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Low temperature for more focused generation
low_temp_sampler = keras_hub.samplers.RandomSampler(temperature=0.3)
output_focused = model.generate(
    "The weather today is",
    max_length=30,
    sampler=low_temp_sampler,
)

# High temperature for more creative generation
high_temp_sampler = keras_hub.samplers.RandomSampler(temperature=1.5)
output_creative = model.generate(
    "The weather today is",
    max_length=30,
    sampler=high_temp_sampler,
)

print("Focused output:", output_focused)
print("Creative output:", output_creative)

# Example: top-k sampling.
import keras_hub
# Load model
model = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Top-k sampling with different k values
small_k_sampler = keras_hub.samplers.TopKSampler(k=10, temperature=0.8)
large_k_sampler = keras_hub.samplers.TopKSampler(k=100, temperature=0.8)
prompt = "In the distant future"

# More conservative generation (smaller k)
output_conservative = model.generate(prompt, max_length=40, sampler=small_k_sampler)

# More diverse generation (larger k)
output_diverse = model.generate(prompt, max_length=40, sampler=large_k_sampler)

print("Conservative (k=10):", output_conservative)
print("Diverse (k=100):", output_diverse)

# Example: top-p (nucleus) sampling.
import keras_hub
# Load model
model = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Top-p sampling adapts to the probability distribution
sampler = keras_hub.samplers.TopPSampler(p=0.9, temperature=0.8)

# Generate multiple outputs to see diversity
prompt = "Once upon a time"
for i in range(3):
    output = model.generate(prompt, max_length=25, sampler=sampler)
    print(f"Output {i+1}: {output}")

# Example: beam search.
import keras_hub
# Load model
model = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Beam search with different beam sizes
beam_sampler = keras_hub.samplers.BeamSampler(
    num_beams=5,
    return_all_beams=False,  # Return only best beam
)
prompt = "The most important discovery in science"
output = model.generate(prompt, max_length=35, sampler=beam_sampler)
print("Beam search output:", output)

# Return all beams to see alternatives
all_beams_sampler = keras_hub.samplers.BeamSampler(
    num_beams=3,
    return_all_beams=True,
)
all_outputs = model.generate(prompt, max_length=25, sampler=all_beams_sampler)
for i, beam_output in enumerate(all_outputs):
    print(f"Beam {i+1}: {beam_output}")

# Example: contrastive search.
import keras_hub
# Load model
model = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Contrastive search balances probability and novelty
sampler = keras_hub.samplers.ContrastiveSampler(
    k=4,  # Number of top tokens to consider
    alpha=0.6,  # Balance between probability and novelty
)
prompt = "Artificial intelligence will change the world by"
output = model.generate(prompt, max_length=50, sampler=sampler)
print("Contrastive search output:", output)

# Example: comparing samplers side by side.
import keras_hub
# Load model
model = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Define different samplers
samplers = {
    "Greedy": keras_hub.samplers.GreedySampler(),
    "Random (T=0.8)": keras_hub.samplers.RandomSampler(temperature=0.8),
    "Top-k (k=50)": keras_hub.samplers.TopKSampler(k=50, temperature=0.8),
    "Top-p (p=0.9)": keras_hub.samplers.TopPSampler(p=0.9, temperature=0.8),
    "Contrastive": keras_hub.samplers.ContrastiveSampler(k=4, alpha=0.6),
}
prompt = "The key to happiness is"

# Generate with each sampler
for name, sampler in samplers.items():
    output = model.generate(prompt, max_length=30, sampler=sampler)
    print(f"{name}: {output}")

# Example: serializing and restoring a sampler.
import keras_hub
# Create a sampler
original_sampler = keras_hub.samplers.TopKSampler(k=40, temperature=0.7)

# Serialize to dictionary
config = keras_hub.samplers.serialize(original_sampler)
print("Serialized config:", config)

# Deserialize back to sampler
restored_sampler = keras_hub.samplers.deserialize(config)

# Use restored sampler
model = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")
output = model.generate("Hello world", max_length=20, sampler=restored_sampler)
print("Generated with restored sampler:", output)

# Example: looking up samplers with `get`.
import keras_hub
# Get sampler by string identifier
greedy = keras_hub.samplers.get("greedy")
random_sampler = keras_hub.samplers.get("random")  # renamed: avoid shadowing the `random` builtin module name

# Get existing sampler instance (returns same instance)
top_k = keras_hub.samplers.TopKSampler(k=50)
same_sampler = keras_hub.samplers.get(top_k)

print("Greedy sampler:", type(greedy).__name__)
print("Random sampler:", type(random_sampler).__name__)
print("Same instance:", top_k is same_sampler)

# Example: applying samplers to raw logits.
import keras_hub
import numpy as np

# Load model and get logits manually
model = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Get next token logits for a prompt
# NOTE(review): assumes the backbone output is (batch, seq, vocab) — confirm
prompt_tokens = model.preprocessor.tokenizer(["Hello world"])
logits = model.backbone(prompt_tokens)[:, -1, :]  # Last token logits

# Apply different samplers to the same logits
samplers = [
    keras_hub.samplers.GreedySampler(),
    keras_hub.samplers.TopKSampler(k=10),
    keras_hub.samplers.TopPSampler(p=0.8),
]
for sampler in samplers:
    # Sample next token
    next_token = sampler(logits, prompt_tokens, generated_tokens=None)
    print(f"{type(sampler).__name__}: token {next_token}")

# Install with Tessl CLI
npx tessl i tessl/pypi-keras-hub