# State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow.
#
# Advanced text generation capabilities with multiple decoding strategies,
# fine-grained control over output, and support for conversational AI. The
# generation system provides flexible interfaces for autoregressive text
# generation with extensive customization options.
#
# Core generation functionality is available on all generative models.
class GenerationMixin:
    """Core text-generation functionality mixed into all generative models.

    Exposes ``generate`` as the high-level entry point, plus the individual
    decoding strategies (greedy search, sampling, beam-search variants and
    contrastive search) as dedicated methods.
    """

    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        use_model_defaults: Optional[bool] = None,
        custom_generate: Optional[Union[str, Callable]] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Generate token sequences using the model.

        Args:
            inputs: Input token IDs used as the prompt.
            generation_config: Generation configuration.
            logits_processor: Custom logits processors.
            stopping_criteria: Custom stopping criteria.
            prefix_allowed_tokens_fn: Constrain generation to allowed tokens.
            synced_gpus: Synchronize GPUs in a distributed setting.
            assistant_model: Assistant model for speculative decoding.
            streamer: Streamer for real-time generation output.
            negative_prompt_ids: Negative prompt for guidance.
            negative_prompt_attention_mask: Attention mask for the negative
                prompt.
            use_model_defaults: Use the model's default generation config.
            custom_generate: Custom generation function or string identifier.
            **kwargs: Additional generation parameters.

        Returns:
            Generated token sequences.
        """

    def beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        **kwargs,
    ) -> Union[GenerateBeamOutput, torch.LongTensor]:
        """Beam search decoding."""

    def beam_sample(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        **kwargs,
    ) -> Union[GenerateBeamOutput, torch.LongTensor]:
        """Beam search with sampling."""

    def group_beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        **kwargs,
    ) -> Union[GenerateBeamOutput, torch.LongTensor]:
        """Diverse beam search with beam groups."""

    def sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        **kwargs,
    ) -> Union[GenerateSampleOutput, torch.LongTensor]:
        """Sampling-based generation."""

    def greedy_search(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        **kwargs,
    ) -> Union[GenerateGreedyOutput, torch.LongTensor]:
        """Greedy decoding."""

    def contrastive_search(
        self,
        input_ids: torch.LongTensor,
        penalty_alpha: float,
        top_k: int,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        **kwargs,
    ) -> Union[GenerateContrastiveOutput, torch.LongTensor]:
        """Contrastive search decoding."""


# Comprehensive configuration for generation parameters and strategies.
class GenerationConfig:
    """Comprehensive configuration for generation parameters and strategies."""

    def __init__(
        self,
        # Length parameters
        max_length: int = 20,
        max_new_tokens: Optional[int] = None,
        min_length: int = 0,
        min_new_tokens: Optional[int] = None,
        early_stopping: Union[bool, str] = False,
        max_time: Optional[float] = None,
        # Generation strategy
        do_sample: bool = False,
        num_beams: int = 1,
        num_beam_groups: int = 1,
        penalty_alpha: Optional[float] = None,
        use_cache: bool = True,
        # Sampling parameters
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 1.0,
        typical_p: float = 1.0,
        epsilon_cutoff: float = 0.0,
        eta_cutoff: float = 0.0,
        diversity_penalty: float = 0.0,
        # Repetition parameters
        repetition_penalty: float = 1.0,
        no_repeat_ngram_size: int = 0,
        encoder_no_repeat_ngram_size: int = 0,
        # Special tokens
        bos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        decoder_start_token_id: Optional[int] = None,
        # Generation control
        num_return_sequences: int = 1,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        output_scores: bool = False,
        return_dict_in_generate: bool = False,
        forced_bos_token_id: Optional[int] = None,
        forced_eos_token_id: Optional[Union[int, List[int]]] = None,
        remove_invalid_values: bool = False,
        exponential_decay_length_penalty: Optional[Tuple[int, float]] = None,
        suppress_tokens: Optional[List[int]] = None,
        begin_suppress_tokens: Optional[List[int]] = None,
        forced_decoder_ids: Optional[List[List[int]]] = None,
        # Sequence bias
        sequence_bias: Optional[Dict[Tuple[int], float]] = None,
        guidance_scale: Optional[float] = None,
        low_memory: Optional[bool] = None,
        # Watermarking
        watermarking_config: Optional[Dict] = None,
        **kwargs,
    ):
        """Configuration for text generation.

        Key parameters:
            max_length: Maximum total sequence length.
            max_new_tokens: Maximum number of new tokens to generate.
            min_length: Minimum sequence length.
            do_sample: Use sampling instead of greedy/beam search.
            num_beams: Number of beams for beam search.
            temperature: Sampling temperature (higher = more random).
            top_k: Keep only the top-k tokens for sampling.
            top_p: Nucleus sampling probability threshold.
            repetition_penalty: Penalty for repeated tokens.
            no_repeat_ngram_size: Prevent repeating n-grams of this size.
            num_return_sequences: Number of sequences to generate.
        """

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name: str,
        config_file_name: Optional[str] = None,
        cache_dir: Optional[str] = None,
        force_download: bool = False,
        **kwargs,
    ) -> "GenerationConfig":
        """Load a generation config from a pretrained model."""

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        config_file_name: Optional[str] = None,
        push_to_hub: bool = False,
        **kwargs,
    ) -> None:
        """Save the generation config to a directory."""

    def update(self, **kwargs) -> None:
        """Update the configuration with new parameters."""


# Advanced beam search with scoring and ranking capabilities.
class BeamScorer:
    """Base class for beam search scoring."""

    def process(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        **kwargs,
    ) -> Tuple[torch.Tensor]:
        """Process beam candidates for one generation step."""

    def finalize(
        self,
        input_ids: torch.LongTensor,
        final_beam_scores: torch.FloatTensor,
        final_beam_tokens: torch.LongTensor,
        final_beam_indices: torch.LongTensor,
        **kwargs,
    ) -> torch.LongTensor:
        """Finalize beam search and return the selected sequences."""


class BeamSearchScorer(BeamScorer):
    """Standard beam search scorer with length penalty and early stopping."""

    def __init__(
        self,
        batch_size: int,
        num_beams: int,
        device: torch.device,
        length_penalty: Optional[float] = 1.0,
        do_early_stopping: Optional[bool] = False,
        num_beam_hyps_to_keep: Optional[int] = 1,
        num_beam_groups: Optional[int] = 1,
        **kwargs,
    ):
        """Beam search scorer with length penalty and early stopping.

        Args:
            batch_size: Batch size.
            num_beams: Number of beams.
            device: Device to run on.
            length_penalty: Length penalty for beam scoring.
            do_early_stopping: Stop when finding complete sequences.
            num_beam_hyps_to_keep: Number of hypotheses to keep.
            num_beam_groups: Number of beam groups for diverse search.
        """


class ConstrainedBeamSearchScorer(BeamScorer):
    """Beam search with lexical constraints."""

    def __init__(
        self,
        batch_size: int,
        num_beams: int,
        device: torch.device,
        constraints: List[Constraint],
        **kwargs,
    ):
        """Beam search scorer that enforces the given lexical constraints."""


# Customizable logits processing for generation control.
class LogitsProcessor:
    """Base class for logits processors."""

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Process logits before sampling/selection."""


class LogitsProcessorList(List[LogitsProcessor]):
    """List of logits processors applied sequentially."""


class TemperatureLogitsWarper(LogitsProcessor):
    """Temperature scaling of logits."""

    def __init__(self, temperature: float):
        """Apply temperature scaling to logits."""


class TopKLogitsWarper(LogitsProcessor):
    """Top-k filtering of logits."""

    def __init__(
        self,
        top_k: int,
        filter_value: float = float("-inf"),
        min_tokens_to_keep: int = 1,
    ):
        """Keep only the top-k tokens; set all others to ``filter_value``."""


class TopPLogitsWarper(LogitsProcessor):
    """Nucleus (top-p) filtering of logits."""

    def __init__(
        self,
        top_p: float,
        filter_value: float = float("-inf"),
        min_tokens_to_keep: int = 1,
    ):
        """Nucleus sampling: keep tokens with cumulative probability <= top_p."""


class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
    """Penalty on previously generated tokens."""

    def __init__(self, penalty: float):
        """Apply a repetition penalty to previously generated tokens."""


class NoRepeatNGramLogitsProcessor(LogitsProcessor):
    """Blocks repeated n-grams."""

    def __init__(self, ngram_size: int):
        """Prevent repeating n-grams of the given size."""


# Flexible stopping conditions for generation.
class StoppingCriteria:
    """Base class for stopping criteria."""

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
        **kwargs,
    ) -> bool:
        """Return whether generation should stop."""


class StoppingCriteriaList(List[StoppingCriteria]):
    """List of stopping criteria combined with OR logic."""


class MaxLengthCriteria(StoppingCriteria):
    """Stops at a maximum sequence length."""

    def __init__(self, max_length: int):
        """Stop when reaching ``max_length`` tokens."""


class MaxTimeCriteria(StoppingCriteria):
    """Stops after a time budget is exhausted."""

    def __init__(self, max_time: float):
        """Stop when exceeding ``max_time``."""


class KeywordsStoppingCriteria(StoppingCriteria):
    """Stops when specific keywords are generated."""

    def __init__(
        self,
        keywords: List[str],
        tokenizer: PreTrainedTokenizer,
    ):
        """Stop when generating any of the given keywords."""


# Structured outputs from different generation methods.
class GenerateOutput:
    """Base output type for generation."""

    # Generated token IDs.
    sequences: torch.LongTensor
    # Optional per-step scores; None unless requested.
    scores: Optional[Tuple[torch.FloatTensor]] = None
    # Optional attention tensors; None unless requested.
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Optional hidden states; None unless requested.
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


class GenerateBeamOutput(GenerateOutput):
    """Output from beam search generation."""

    # Optional final scores of the returned sequences.
    sequences_scores: Optional[torch.FloatTensor] = None
    # Optional beam indices for the returned sequences.
    beam_indices: Optional[torch.LongTensor] = None


class GenerateSampleOutput(GenerateOutput):
    """Output from sampling generation."""


class GenerateGreedyOutput(GenerateOutput):
    """Output from greedy generation."""


# Real-time streaming of generated text.
class BaseStreamer:
    """Base class for generation streamers."""

    def put(self, value: torch.LongTensor) -> None:
        """Process newly generated tokens."""

    def end(self) -> None:
        """Signal the end of generation."""


class TextStreamer(BaseStreamer):
    """Streams generated text to stdout."""

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        skip_prompt: bool = False,
        skip_special_tokens: bool = False,
        **decode_kwargs,
    ):
        """Stream generated text to stdout.

        Args:
            tokenizer: Tokenizer used for decoding.
            skip_prompt: Skip printing the input prompt.
            skip_special_tokens: Skip special tokens in the output.
            **decode_kwargs: Extra arguments for ``tokenizer.decode()``.
        """


class TextIteratorStreamer(BaseStreamer):
    """Streams generated text through an iterator interface."""

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        skip_prompt: bool = False,
        timeout: Optional[float] = None,
        **decode_kwargs,
    ):
        """Stream generated text through an iterator interface.

        Args:
            tokenizer: Tokenizer used for decoding.
            skip_prompt: Skip the input prompt.
            timeout: Timeout for iteration.
            **decode_kwargs: Extra arguments for ``tokenizer.decode()``.
        """

    def __iter__(self) -> Iterator[str]:
        """Iterate over generated text chunks."""


# Common generation patterns and use cases.
# Example usage: common generation patterns.
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    KeywordsStoppingCriteria,
    TextStreamer,
)

# Load model and tokenizer.
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

# Basic generation.
prompt = "The future of artificial intelligence is"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=50)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Sampling with temperature.
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    do_sample=True,
    temperature=0.8,
    top_k=50,
    top_p=0.9,
)

# Beam search.
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    num_beams=5,
    early_stopping=True,
)

# Multiple sequences.
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    num_return_sequences=3,
    do_sample=True,
    temperature=0.8,
)

# With a custom generation config.
gen_config = GenerationConfig(
    max_new_tokens=100,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.1,
    no_repeat_ngram_size=2,
)
outputs = model.generate(**inputs, generation_config=gen_config)

# Streaming generation.
streamer = TextStreamer(tokenizer, skip_prompt=True)
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    streamer=streamer,
)

# Constrained generation: stop on specific keywords.
stop_words = ["END", "STOP"]
stopping_criteria = KeywordsStoppingCriteria(stop_words, tokenizer)
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    stopping_criteria=[stopping_criteria],
)

# Install with Tessl CLI
#   npx tessl i tessl/pypi-transformers