CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-transformers

State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow

Overview
Eval results
Files

docs/generation.md

Generation

Advanced text generation capabilities with multiple decoding strategies, fine-grained control over output, and support for conversational AI. The generation system provides flexible interfaces for autoregressive text generation with extensive customization options.

Capabilities

Generation Mixin

Core generation functionality available on all generative models.

class GenerationMixin:
    """Mixin that adds text-generation entry points to generative models.

    ``generate()`` is the main public interface and dispatches to a
    decoding strategy; the remaining methods each implement one strategy
    (greedy, sampling, beam variants, contrastive search) and accept the
    same logits-processor and stopping-criteria hooks.
    """

    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        use_model_defaults: Optional[bool] = None,
        custom_generate: Optional[Union[str, Callable]] = None,
        **kwargs
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """
        Generate sequences using the model.
        
        Args:
            inputs: Input token IDs
            generation_config: Generation configuration
            logits_processor: Custom logits processors
            stopping_criteria: Custom stopping criteria
            prefix_allowed_tokens_fn: Constrain generation to allowed tokens
            synced_gpus: Synchronize GPUs in distributed setting
            assistant_model: Assistant model for speculative decoding
            streamer: Streamer for real-time generation output
            negative_prompt_ids: Negative prompt for guidance
            negative_prompt_attention_mask: Attention mask for negative prompt
            use_model_defaults: Use model's default generation config
            custom_generate: Custom generation function or string identifier
            **kwargs: Additional generation parameters
        
        Returns:
            Generated token sequences (a plain tensor, or a structured
            GenerateOutput when return_dict_in_generate is requested)
        """
    
    def beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        **kwargs
    ) -> Union[GenerateBeamOutput, torch.LongTensor]:
        """Beam search decoding.

        Args:
            input_ids: Starting token ids for the sequences
            beam_scorer: Scorer that ranks and finalizes beam hypotheses
            logits_processor: Custom logits processors
            stopping_criteria: Custom stopping criteria
        """
    
    def beam_sample(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        **kwargs
    ) -> Union[GenerateBeamOutput, torch.LongTensor]:
        """Beam search with sampling (stochastic beam expansion)."""
    
    def group_beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        **kwargs
    ) -> Union[GenerateBeamOutput, torch.LongTensor]:
        """Diverse beam search with groups (beams split into groups that
        are penalized for producing similar tokens)."""
    
    def sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        **kwargs
    ) -> Union[GenerateSampleOutput, torch.LongTensor]:
        """Sampling-based generation (draws each next token from the
        processed distribution)."""
    
    def greedy_search(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        **kwargs
    ) -> Union[GenerateGreedyOutput, torch.LongTensor]:
        """Greedy decoding (always picks the highest-scoring next token)."""
    
    def contrastive_search(
        self,
        input_ids: torch.LongTensor,
        penalty_alpha: float,
        top_k: int,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        **kwargs
    ) -> Union[GenerateContrastiveOutput, torch.LongTensor]:
        """Contrastive search decoding.

        Args:
            penalty_alpha: Weight of the degeneration penalty
            top_k: Number of candidate tokens considered per step
        """

Generation Configuration

Comprehensive configuration for generation parameters and strategies.

class GenerationConfig:
    """Container for all generation parameters and strategy selection.

    An instance can be passed to ``GenerationMixin.generate()`` via the
    ``generation_config`` argument, loaded from / saved alongside a
    pretrained model, or updated in place with ``update()``.
    """

    def __init__(
        self,
        # Length parameters
        max_length: int = 20,
        max_new_tokens: Optional[int] = None,
        min_length: int = 0,
        min_new_tokens: Optional[int] = None,
        early_stopping: Union[bool, str] = False,
        max_time: Optional[float] = None,
        
        # Generation strategy
        do_sample: bool = False,
        num_beams: int = 1,
        num_beam_groups: int = 1,
        penalty_alpha: Optional[float] = None,
        use_cache: bool = True,
        
        # Sampling parameters
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 1.0,
        typical_p: float = 1.0,
        epsilon_cutoff: float = 0.0,
        eta_cutoff: float = 0.0,
        diversity_penalty: float = 0.0,
        
        # Repetition parameters
        repetition_penalty: float = 1.0,
        no_repeat_ngram_size: int = 0,
        encoder_no_repeat_ngram_size: int = 0,
        
        # Special tokens
        bos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        decoder_start_token_id: Optional[int] = None,
        
        # Generation control
        num_return_sequences: int = 1,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        output_scores: bool = False,
        return_dict_in_generate: bool = False,
        forced_bos_token_id: Optional[int] = None,
        forced_eos_token_id: Optional[Union[int, List[int]]] = None,
        remove_invalid_values: bool = False,
        exponential_decay_length_penalty: Optional[Tuple[int, float]] = None,
        suppress_tokens: Optional[List[int]] = None,
        begin_suppress_tokens: Optional[List[int]] = None,
        forced_decoder_ids: Optional[List[List[int]]] = None,
        
        # Sequence bias
        sequence_bias: Optional[Dict[Tuple[int], float]] = None,
        guidance_scale: Optional[float] = None,
        low_memory: Optional[bool] = None,
        
        # Watermarking
        watermarking_config: Optional[Dict] = None,
        
        **kwargs
    ):
        """
        Configuration for text generation.
        
        Key parameters:
            max_length: Maximum total sequence length
            max_new_tokens: Maximum number of new tokens to generate
                (takes precedence over max_length when set)
            min_length: Minimum sequence length
            do_sample: Use sampling instead of greedy/beam search
            num_beams: Number of beams for beam search (1 = no beam search)
            temperature: Sampling temperature (higher = more random)
            top_k: Keep only top-k tokens for sampling
            top_p: Nucleus sampling probability threshold
            repetition_penalty: Penalty for repeated tokens (1.0 = none)
            no_repeat_ngram_size: Prevent repeating n-grams (0 = disabled)
            num_return_sequences: Number of sequences to generate
        """
    
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name: str,
        config_file_name: Optional[str] = None,
        cache_dir: Optional[str] = None,
        force_download: bool = False,
        **kwargs
    ) -> "GenerationConfig":
        """Load generation config from pretrained model.

        Args:
            pretrained_model_name: Model identifier or local path
            config_file_name: Override for the config file name
            cache_dir: Directory for cached downloads
            force_download: Re-download even if cached
        """
    
    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        config_file_name: Optional[str] = None,
        push_to_hub: bool = False,
        **kwargs
    ) -> None:
        """Save generation config to directory.

        Args:
            save_directory: Target directory for the config file
            config_file_name: Override for the config file name
            push_to_hub: Also upload the saved config to the Hub
        """
    
    def update(self, **kwargs) -> None:
        """Update configuration with new parameters (keyword = attribute)."""

Beam Search Scoring

Advanced beam search with scoring and ranking capabilities.

class BeamScorer:
    """Base class for beam search scoring.

    Subclasses implement ``process`` (rank candidate continuations
    during search) and ``finalize`` (select the hypotheses returned
    when search ends).
    """
    
    def process(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        **kwargs
    ) -> Tuple[torch.Tensor]:
        """Process beam candidates.

        Args:
            input_ids: Current sequences for the live beams
            next_scores: Scores of the candidate next tokens
            next_tokens: Candidate next token ids
            next_indices: Indices of the beams each candidate extends
        """
    
    def finalize(
        self,
        input_ids: torch.LongTensor,
        final_beam_scores: torch.FloatTensor,
        final_beam_tokens: torch.LongTensor,
        final_beam_indices: torch.LongTensor,
        **kwargs
    ) -> torch.LongTensor:
        """Finalize beam search and return the selected sequences."""

class BeamSearchScorer(BeamScorer):
    def __init__(
        self,
        batch_size: int,
        num_beams: int,
        device: torch.device,
        length_penalty: Optional[float] = 1.0,
        do_early_stopping: Optional[bool] = False,
        num_beam_hyps_to_keep: Optional[int] = 1,
        num_beam_groups: Optional[int] = 1,
        **kwargs
    ):
        """
        Beam search scorer with length penalty and early stopping.
        
        Args:
            batch_size: Batch size
            num_beams: Number of beams per batch item
            device: Device to run on
            length_penalty: Length penalty for beam scoring
                (1.0 = neutral; >1.0 favors longer sequences)
            do_early_stopping: Stop when finding complete sequences
            num_beam_hyps_to_keep: Number of hypotheses to keep
            num_beam_groups: Number of beam groups for diverse search
        """

class ConstrainedBeamSearchScorer(BeamScorer):
    def __init__(
        self,
        batch_size: int,
        num_beams: int,
        device: torch.device,
        constraints: List[Constraint],
        **kwargs
    ):
        """Beam search with lexical constraints.

        Args:
            batch_size: Batch size
            num_beams: Number of beams per batch item
            device: Device to run on
            constraints: Lexical constraints the returned sequences
                must satisfy
        """

Logits Processing

Customizable logits processing for generation control.

class LogitsProcessor:
    """Base class for logits processors.

    A processor maps (input_ids, scores) -> scores, rewriting the
    next-token logits before a token is sampled or selected.
    """
    
    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Process logits before sampling/selection.

        Args:
            input_ids: Token ids generated so far
            scores: Next-token logits to be modified

        Returns:
            The processed logits
        """

class LogitsProcessorList(List[LogitsProcessor]):
    """List of logits processors applied sequentially, in list order."""

class TemperatureLogitsWarper(LogitsProcessor):
    def __init__(self, temperature: float):
        """Apply temperature scaling to logits.

        Args:
            temperature: Scaling factor; values > 1 flatten the
                distribution, values < 1 sharpen it.
        """

class TopKLogitsWarper(LogitsProcessor):
    def __init__(
        self,
        top_k: int,
        filter_value: float = float("-inf"),
        min_tokens_to_keep: int = 1
    ):
        """Keep only top-k tokens, set others to filter_value.

        Args:
            top_k: Number of highest-scoring tokens to keep
            filter_value: Value assigned to filtered-out tokens
            min_tokens_to_keep: Lower bound on tokens that survive filtering
        """

class TopPLogitsWarper(LogitsProcessor):
    def __init__(
        self,
        top_p: float,
        filter_value: float = float("-inf"),
        min_tokens_to_keep: int = 1
    ):
        """Nucleus sampling: keep tokens with cumulative probability <= top_p.

        Args:
            top_p: Cumulative probability threshold in (0, 1]
            filter_value: Value assigned to filtered-out tokens
            min_tokens_to_keep: Lower bound on tokens that survive filtering
        """

class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
    def __init__(self, penalty: float):
        """Apply repetition penalty to previously generated tokens.

        Args:
            penalty: Penalty factor; 1.0 disables the penalty.
        """

class NoRepeatNGramLogitsProcessor(LogitsProcessor):
    def __init__(self, ngram_size: int):
        """Prevent repeating n-grams.

        Args:
            ngram_size: Length of the n-grams that may occur only once.
        """

Stopping Criteria

Flexible stopping conditions for generation.

class StoppingCriteria:
    """Base class for stopping criteria."""
    
    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
        **kwargs
    ) -> bool:
        """Check if generation should stop.

        Args:
            input_ids: Token ids generated so far
            scores: Current next-token scores

        Returns:
            True when generation should halt.
        """

class StoppingCriteriaList(List[StoppingCriteria]):
    """List of stopping criteria (OR logic): generation stops as soon as
    any contained criterion returns True."""

class MaxLengthCriteria(StoppingCriteria):
    def __init__(self, max_length: int):
        """Stop when reaching maximum length.

        Args:
            max_length: Maximum total sequence length in tokens.
        """

class MaxTimeCriteria(StoppingCriteria):
    def __init__(self, max_time: float):
        """Stop when exceeding maximum time.

        Args:
            max_time: Time budget for generation, in seconds.
        """

class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(
        self,
        keywords: List[str],
        tokenizer: PreTrainedTokenizer
    ):
        """Stop when generating specific keywords.

        Args:
            keywords: Strings whose appearance ends generation
            tokenizer: Tokenizer used to decode generated tokens for
                keyword matching
        """

Generation Output Types

Structured outputs from different generation methods.

class GenerateOutput:
    """Base output type for generation."""
    # Generated token id sequences.
    sequences: torch.LongTensor
    # Per-step token scores; populated when output_scores is requested.
    scores: Optional[Tuple[torch.FloatTensor]] = None
    # Per-step attention weights; populated when output_attentions is requested.
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Per-step hidden states; populated when output_hidden_states is requested.
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None

class GenerateBeamOutput(GenerateOutput):
    """Output from beam search generation."""
    # Final score of each returned sequence.
    sequences_scores: Optional[torch.FloatTensor] = None
    # Beam index associated with each generated token.
    beam_indices: Optional[torch.LongTensor] = None

class GenerateSampleOutput(GenerateOutput):
    """Output from sampling generation (fields inherited from GenerateOutput)."""

class GenerateGreedyOutput(GenerateOutput):
    """Output from greedy generation (fields inherited from GenerateOutput)."""

Streaming Generation

Real-time streaming of generated text.

class BaseStreamer:
    """Base class for generation streamers.

    During generation, newly produced tokens are pushed to ``put()``
    and ``end()`` is called once generation finishes.
    """
    
    def put(self, value: torch.LongTensor) -> None:
        """Process new generated tokens."""
    
    def end(self) -> None:
        """Signal end of generation."""

class TextStreamer(BaseStreamer):
    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        skip_prompt: bool = False,
        skip_special_tokens: bool = False,
        **decode_kwargs
    ):
        """
        Stream generated text to stdout as it is produced.
        
        Args:
            tokenizer: Tokenizer for decoding generated token ids
            skip_prompt: Skip printing the input prompt
            skip_special_tokens: Skip special tokens in output
            **decode_kwargs: Arguments for tokenizer.decode()
        """

class TextIteratorStreamer(BaseStreamer):
    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        skip_prompt: bool = False,
        timeout: Optional[float] = None,
        **decode_kwargs
    ):
        """
        Stream generated text through iterator interface.
        
        Args:
            tokenizer: Tokenizer for decoding generated token ids
            skip_prompt: Skip the input prompt
            timeout: Timeout (seconds) when waiting for the next chunk
            **decode_kwargs: Arguments for tokenizer.decode()
        """
    
    def __iter__(self) -> Iterator[str]:
        """Iterate over generated text chunks as they become available."""

Generation Examples

Common generation patterns and use cases:

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# GPT-2 ships without a pad token; reuse EOS so inputs can be padded
tokenizer.pad_token = tokenizer.eos_token

# Basic generation (greedy decoding, the default strategy)
prompt = "The future of artificial intelligence is"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=50)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Sampling with temperature (do_sample=True switches to stochastic decoding)
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    do_sample=True,
    temperature=0.8,
    top_k=50,
    top_p=0.9
)

# Beam search (num_beams > 1 enables it)
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    num_beams=5,
    early_stopping=True
)

# Multiple sequences per prompt (requires sampling or beam search)
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    num_return_sequences=3,
    do_sample=True,
    temperature=0.8
)

# With custom generation config (reusable across generate() calls)
gen_config = GenerationConfig(
    max_new_tokens=100,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.1,
    no_repeat_ngram_size=2
)

outputs = model.generate(**inputs, generation_config=gen_config)

# Streaming generation: tokens are printed to stdout as they are produced
from transformers import TextStreamer
streamer = TextStreamer(tokenizer, skip_prompt=True)

outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    streamer=streamer
)

# Keyword-based early stopping: halt as soon as a stop word is generated
from transformers import KeywordsStoppingCriteria
stop_words = ["END", "STOP"]
stopping_criteria = KeywordsStoppingCriteria(stop_words, tokenizer)

outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    stopping_criteria=[stopping_criteria]
)

Install with Tessl CLI

npx tessl i tessl/pypi-transformers

docs

feature-extraction.md

generation.md

index.md

models.md

optimization.md

pipelines.md

tokenization.md

training.md

tile.json