CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-keras-hub

Pretrained models for Keras with multi-framework compatibility.

Pending
Overview
Eval results
Files

docs/text-generation-sampling.md

Text Generation Sampling

Sampling strategies for controlling text generation behavior in language models. Keras Hub provides various sampling methods to balance between quality, diversity, and controllability in generated text.

Capabilities

Base Classes

Foundation classes for text generation sampling.

class Sampler:
    """Base class for all samplers.

    A sampler chooses the next token during autoregressive text
    generation. Concrete strategies (greedy, beam, random, top-k,
    top-p, contrastive) subclass this and implement the selection
    logic.
    """
    # **kwargs only; the base class declares no options of its own.
    def __init__(self, **kwargs): ...

    # Invoke the sampler for one generation step.
    # next_token_logits: unnormalized vocabulary scores
    #   (presumably shape (batch, vocab_size) — TODO confirm).
    # prompt_tokens / generated_tokens: context made available to the
    #   strategy (e.g. for repetition penalties).
    def __call__(
        self,
        next_token_logits,
        prompt_tokens,
        generated_tokens,
        **kwargs
    ): ...

    # Select a token id from a probability distribution over the vocab.
    def get_next_token(self, probabilities): ...

Deterministic Sampling

Samplers that produce deterministic outputs given the same input.

class GreedySampler(Sampler):
    """
    Greedy sampling always selects the token with highest probability.
    Produces deterministic but potentially repetitive outputs.
    """
    # No options of its own; **kwargs is forwarded to `Sampler`.
    def __init__(self, **kwargs): ...

class BeamSampler(Sampler):
    """
    Beam search maintains multiple candidate sequences and selects
    the sequence with highest overall probability.

    Args:
        num_beams: Number of candidate sequences kept alive at each
            generation step.
        return_all_beams: If True, return every beam instead of only
            the highest-scoring one.
    """
    def __init__(
        self,
        num_beams: int = 5,
        return_all_beams: bool = False,
        **kwargs
    ): ...

Stochastic Sampling

Samplers that introduce randomness for more diverse outputs.

class RandomSampler(Sampler):
    """
    Random sampling selects tokens according to their probability distribution.
    Higher temperature increases randomness.

    Args:
        temperature: Scales the distribution before sampling; values > 1
            flatten it (more random), values < 1 sharpen it.
        seed: Optional seed for reproducible draws.
    """
    def __init__(
        self,
        temperature: float = 1.0,
        # Annotation corrected: the default is None, so the type is int | None.
        seed: int | None = None,
        **kwargs
    ): ...

class TopKSampler(Sampler):
    """
    Top-k sampling considers only the k most likely tokens at each step.
    Balances quality and diversity by filtering low-probability tokens.

    Args:
        k: Number of highest-probability tokens kept as candidates.
        temperature: Scales the distribution before sampling.
        seed: Optional seed for reproducible draws.
    """
    def __init__(
        self,
        k: int = 50,
        temperature: float = 1.0,
        # Annotation corrected: the default is None, so the type is int | None.
        seed: int | None = None,
        **kwargs
    ): ...

class TopPSampler(Sampler):
    """
    Top-p (nucleus) sampling considers tokens whose cumulative probability
    is within the top p fraction. Adapts the number of considered tokens
    based on the probability distribution.

    Args:
        p: Cumulative-probability threshold defining the nucleus.
        temperature: Scales the distribution before sampling.
        seed: Optional seed for reproducible draws.
    """
    def __init__(
        self,
        p: float = 0.9,
        temperature: float = 1.0,
        # Annotation corrected: the default is None, so the type is int | None.
        seed: int | None = None,
        **kwargs
    ): ...

Advanced Sampling

More sophisticated sampling strategies for improved generation quality.

class ContrastiveSampler(Sampler):
    """
    Contrastive search balances high probability and low repetition
    by penalizing tokens that are too similar to previously generated tokens.

    Args:
        k: Number of top candidate tokens considered at each step.
        alpha: Weight balancing model probability against the
            similarity (degeneration) penalty.
    """
    def __init__(
        self,
        k: int = 4,
        alpha: float = 0.6,
        **kwargs
    ): ...

Sampler Utilities

Utilities for working with samplers programmatically.

def serialize(sampler: Sampler) -> dict:
    """
    Serialize a sampler instance to a dictionary.

    The result is suitable for round-tripping through `deserialize`.

    Args:
        sampler: The sampler instance to serialize

    Returns:
        Dictionary representation of the sampler
    """
    ...

def deserialize(config: dict) -> Sampler:
    """
    Deserialize a sampler from a dictionary configuration.

    Inverse of `serialize`.

    Args:
        config: Dictionary configuration of the sampler

    Returns:
        Sampler instance
    """
    ...

def get(identifier: str | Sampler) -> Sampler:
    """
    Get a sampler by name or return existing sampler instance.

    Args:
        identifier: String name (e.g. "greedy", "random") or an existing
            sampler instance, which is returned unchanged.

    Returns:
        Sampler instance
    """
    ...

Usage Examples

Greedy Sampling for Deterministic Output

import keras_hub

# Greedy decoding always picks the most likely next token, so the same
# prompt always produces the same continuation.
gpt2 = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")
greedy = keras_hub.samplers.GreedySampler()

result = gpt2.generate(
    "The future of artificial intelligence",
    max_length=50,
    sampler=greedy,
)
print("Greedy output:", result)

Random Sampling with Temperature Control

import keras_hub

# Temperature rescales the softmax: low values concentrate probability
# mass (focused text), high values flatten it (creative text).
gpt2 = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

weather_prompt = "The weather today is"

# temperature=0.3 -> focused generation
output_focused = gpt2.generate(
    weather_prompt,
    max_length=30,
    sampler=keras_hub.samplers.RandomSampler(temperature=0.3),
)

# temperature=1.5 -> creative generation
output_creative = gpt2.generate(
    weather_prompt,
    max_length=30,
    sampler=keras_hub.samplers.RandomSampler(temperature=1.5),
)

print("Focused output:", output_focused)
print("Creative output:", output_creative)

Top-k Sampling for Quality-Diversity Balance

import keras_hub

# Top-k restricts sampling to the k highest-probability tokens:
# a smaller k is more conservative, a larger k is more diverse.
gpt2 = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

story_prompt = "In the distant future"

output_conservative = gpt2.generate(
    story_prompt,
    max_length=40,
    sampler=keras_hub.samplers.TopKSampler(k=10, temperature=0.8),
)
output_diverse = gpt2.generate(
    story_prompt,
    max_length=40,
    sampler=keras_hub.samplers.TopKSampler(k=100, temperature=0.8),
)

print("Conservative (k=10):", output_conservative)
print("Diverse (k=100):", output_diverse)

Top-p (Nucleus) Sampling

import keras_hub

# Nucleus sampling draws from the smallest token set whose cumulative
# probability reaches p, adapting the candidate count at every step.
gpt2 = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")
nucleus = keras_hub.samplers.TopPSampler(p=0.9, temperature=0.8)

# Repeated generations from one prompt illustrate the diversity.
story_prompt = "Once upon a time"
for attempt in range(3):
    text = gpt2.generate(story_prompt, max_length=25, sampler=nucleus)
    print(f"Output {attempt+1}: {text}")

Beam Search for Best Overall Sequence

import keras_hub

# Beam search keeps several candidate sequences alive and returns the
# candidate(s) with the best overall score.
gpt2 = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

science_prompt = "The most important discovery in science"

# Single best beam out of five candidates.
best_beam = keras_hub.samplers.BeamSampler(
    num_beams=5,
    return_all_beams=False  # Return only best beam
)
best_text = gpt2.generate(science_prompt, max_length=35, sampler=best_beam)
print("Beam search output:", best_text)

# With return_all_beams=True, generate() yields every candidate.
every_beam = keras_hub.samplers.BeamSampler(
    num_beams=3,
    return_all_beams=True
)
candidates = gpt2.generate(science_prompt, max_length=25, sampler=every_beam)
for index, candidate in enumerate(candidates):
    print(f"Beam {index+1}: {candidate}")

Contrastive Search for Reducing Repetition

import keras_hub

# Contrastive search trades off token probability against similarity to
# the already-generated context, which curbs repetition.
gpt2 = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")
contrastive = keras_hub.samplers.ContrastiveSampler(
    k=4,      # Number of top tokens to consider
    alpha=0.6  # Balance between probability and novelty
)

text = gpt2.generate(
    "Artificial intelligence will change the world by",
    max_length=50,
    sampler=contrastive,
)
print("Contrastive search output:", text)

Comparing Different Sampling Methods

import keras_hub

# Run one prompt through several strategies side by side.
gpt2 = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Label -> sampler instance; insertion order controls the print order.
strategies = {
    "Greedy": keras_hub.samplers.GreedySampler(),
    "Random (T=0.8)": keras_hub.samplers.RandomSampler(temperature=0.8),
    "Top-k (k=50)": keras_hub.samplers.TopKSampler(k=50, temperature=0.8),
    "Top-p (p=0.9)": keras_hub.samplers.TopPSampler(p=0.9, temperature=0.8),
    "Contrastive": keras_hub.samplers.ContrastiveSampler(k=4, alpha=0.6)
}

happiness_prompt = "The key to happiness is"

for label, strategy in strategies.items():
    text = gpt2.generate(happiness_prompt, max_length=30, sampler=strategy)
    print(f"{label}: {text}")

Serializing and Deserializing Samplers

import keras_hub

# Round-trip a sampler through its dictionary config.
source_sampler = keras_hub.samplers.TopKSampler(k=40, temperature=0.7)

sampler_config = keras_hub.samplers.serialize(source_sampler)
print("Serialized config:", sampler_config)

restored = keras_hub.samplers.deserialize(sampler_config)

# The restored instance is a drop-in replacement for the original.
gpt2 = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")
text = gpt2.generate("Hello world", max_length=20, sampler=restored)
print("Generated with restored sampler:", text)

Getting Samplers by Name

import keras_hub

# samplers.get() accepts either a registered string name or an existing
# Sampler instance (which it returns unchanged).

# Get samplers by string identifier.
greedy = keras_hub.samplers.get("greedy")
# Named `random_sampler` so the stdlib `random` module name is not shadowed.
random_sampler = keras_hub.samplers.get("random")

# Get existing sampler instance (returns same instance).
top_k = keras_hub.samplers.TopKSampler(k=50)
same_sampler = keras_hub.samplers.get(top_k)

print("Greedy sampler:", type(greedy).__name__)
print("Random sampler:", type(random_sampler).__name__)
print("Same instance:", top_k is same_sampler)

Custom Sampling with Manual Control

import keras_hub
# NOTE: the original snippet also did `import numpy as np`, but numpy was
# never used; the unused import has been removed.

# Drive samplers by hand: compute next-token logits once, then let each
# strategy pick a token from the same distribution.
model = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Tokenize the prompt and keep only the logits at the final position.
prompt_tokens = model.preprocessor.tokenizer(["Hello world"])
logits = model.backbone(prompt_tokens)[:, -1, :]  # Last token logits

# Apply different samplers to the same logits.
samplers = [
    keras_hub.samplers.GreedySampler(),
    keras_hub.samplers.TopKSampler(k=10),
    keras_hub.samplers.TopPSampler(p=0.8)
]

for sampler in samplers:
    # Nothing has been generated yet, so generated_tokens is None.
    next_token = sampler(logits, prompt_tokens, generated_tokens=None)
    print(f"{type(sampler).__name__}: token {next_token}")

Install with Tessl CLI

npx tessl i tessl/pypi-keras-hub

docs

audio-models.md

evaluation-metrics.md

generative-models.md

image-models.md

index.md

layers-components.md

multimodal-models.md

text-generation-sampling.md

text-models.md

tokenizers.md

utilities-helpers.md

tile.json