# Generation

Advanced text generation capabilities with multiple decoding strategies, fine-grained control over output, and support for conversational AI. The generation system provides flexible interfaces for autoregressive text generation with extensive customization options.

## Capabilities

### Generation Mixin

Core generation functionality available on all generative models.

```python { .api }
class GenerationMixin:
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        use_model_defaults: Optional[bool] = None,
        custom_generate: Optional[Union[str, Callable]] = None,
        **kwargs
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """
        Generate sequences using the model.

        Args:
            inputs: Input token IDs
            generation_config: Generation configuration
            logits_processor: Custom logits processors
            stopping_criteria: Custom stopping criteria
            prefix_allowed_tokens_fn: Constrain generation to allowed tokens
            synced_gpus: Synchronize GPUs in distributed setting
            assistant_model: Assistant model for speculative decoding
            streamer: Streamer for real-time generation output
            negative_prompt_ids: Negative prompt for guidance
            negative_prompt_attention_mask: Attention mask for negative prompt
            use_model_defaults: Use model's default generation config
            custom_generate: Custom generation function or string identifier
            **kwargs: Additional generation parameters

        Returns:
            Generated token sequences
        """

    def beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        **kwargs
    ) -> Union[GenerateBeamOutput, torch.LongTensor]:
        """Beam search decoding."""

    def beam_sample(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        **kwargs
    ) -> Union[GenerateBeamOutput, torch.LongTensor]:
        """Beam search with sampling."""

    def group_beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        **kwargs
    ) -> Union[GenerateBeamOutput, torch.LongTensor]:
        """Diverse beam search with groups."""

    def sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        **kwargs
    ) -> Union[GenerateSampleOutput, torch.LongTensor]:
        """Sampling-based generation."""

    def greedy_search(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        **kwargs
    ) -> Union[GenerateGreedyOutput, torch.LongTensor]:
        """Greedy decoding."""

    def contrastive_search(
        self,
        input_ids: torch.LongTensor,
        penalty_alpha: float,
        top_k: int,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        **kwargs
    ) -> Union[GenerateContrastiveOutput, torch.LongTensor]:
        """Contrastive search decoding."""
```

### Generation Configuration

Comprehensive configuration for generation parameters and strategies.

```python { .api }
class GenerationConfig:
    def __init__(
        self,
        # Length parameters
        max_length: int = 20,
        max_new_tokens: Optional[int] = None,
        min_length: int = 0,
        min_new_tokens: Optional[int] = None,
        early_stopping: Union[bool, str] = False,
        max_time: Optional[float] = None,

        # Generation strategy
        do_sample: bool = False,
        num_beams: int = 1,
        num_beam_groups: int = 1,
        penalty_alpha: Optional[float] = None,
        use_cache: bool = True,

        # Sampling parameters
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 1.0,
        typical_p: float = 1.0,
        epsilon_cutoff: float = 0.0,
        eta_cutoff: float = 0.0,
        diversity_penalty: float = 0.0,

        # Repetition parameters
        repetition_penalty: float = 1.0,
        no_repeat_ngram_size: int = 0,
        encoder_no_repeat_ngram_size: int = 0,

        # Special tokens
        bos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        decoder_start_token_id: Optional[int] = None,

        # Generation control
        num_return_sequences: int = 1,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        output_scores: bool = False,
        return_dict_in_generate: bool = False,
        forced_bos_token_id: Optional[int] = None,
        forced_eos_token_id: Optional[Union[int, List[int]]] = None,
        remove_invalid_values: bool = False,
        exponential_decay_length_penalty: Optional[Tuple[int, float]] = None,
        suppress_tokens: Optional[List[int]] = None,
        begin_suppress_tokens: Optional[List[int]] = None,
        forced_decoder_ids: Optional[List[List[int]]] = None,

        # Sequence bias
        sequence_bias: Optional[Dict[Tuple[int], float]] = None,
        guidance_scale: Optional[float] = None,
        low_memory: Optional[bool] = None,

        # Watermarking
        watermarking_config: Optional[Dict] = None,

        **kwargs
    ):
        """
        Configuration for text generation.

        Key parameters:
            max_length: Maximum total sequence length
            max_new_tokens: Maximum number of new tokens to generate
            min_length: Minimum sequence length
            do_sample: Use sampling instead of greedy/beam search
            num_beams: Number of beams for beam search
            temperature: Sampling temperature (higher = more random)
            top_k: Keep only top-k tokens for sampling
            top_p: Nucleus sampling probability threshold
            repetition_penalty: Penalty for repeated tokens
            no_repeat_ngram_size: Prevent repeating n-grams
            num_return_sequences: Number of sequences to generate
        """

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name: str,
        config_file_name: Optional[str] = None,
        cache_dir: Optional[str] = None,
        force_download: bool = False,
        **kwargs
    ) -> "GenerationConfig":
        """Load generation config from pretrained model."""

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        config_file_name: Optional[str] = None,
        push_to_hub: bool = False,
        **kwargs
    ) -> None:
        """Save generation config to directory."""

    def update(self, **kwargs) -> None:
        """Update configuration with new parameters."""
```

### Beam Search Scoring

Advanced beam search with scoring and ranking capabilities.

```python { .api }
class BeamScorer:
    """Base class for beam search scoring."""

    def process(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        **kwargs
    ) -> Tuple[torch.Tensor]:
        """Process beam candidates."""

    def finalize(
        self,
        input_ids: torch.LongTensor,
        final_beam_scores: torch.FloatTensor,
        final_beam_tokens: torch.LongTensor,
        final_beam_indices: torch.LongTensor,
        **kwargs
    ) -> torch.LongTensor:
        """Finalize beam search."""

class BeamSearchScorer(BeamScorer):
    def __init__(
        self,
        batch_size: int,
        num_beams: int,
        device: torch.device,
        length_penalty: Optional[float] = 1.0,
        do_early_stopping: Optional[bool] = False,
        num_beam_hyps_to_keep: Optional[int] = 1,
        num_beam_groups: Optional[int] = 1,
        **kwargs
    ):
        """
        Beam search scorer with length penalty and early stopping.

        Args:
            batch_size: Batch size
            num_beams: Number of beams
            device: Device to run on
            length_penalty: Length penalty for beam scoring
            do_early_stopping: Stop when finding complete sequences
            num_beam_hyps_to_keep: Number of hypotheses to keep
            num_beam_groups: Number of beam groups for diverse search
        """

class ConstrainedBeamSearchScorer(BeamScorer):
    def __init__(
        self,
        batch_size: int,
        num_beams: int,
        device: torch.device,
        constraints: List[Constraint],
        **kwargs
    ):
        """Beam search with lexical constraints."""
```

### Logits Processing

Customizable logits processing for generation control.

```python { .api }
class LogitsProcessor:
    """Base class for logits processors."""

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Process logits before sampling/selection."""

class LogitsProcessorList(List[LogitsProcessor]):
    """List of logits processors applied sequentially."""

class TemperatureLogitsWarper(LogitsProcessor):
    def __init__(self, temperature: float):
        """Apply temperature scaling to logits."""

class TopKLogitsWarper(LogitsProcessor):
    def __init__(
        self,
        top_k: int,
        filter_value: float = float("-inf"),
        min_tokens_to_keep: int = 1
    ):
        """Keep only top-k tokens, set others to filter_value."""

class TopPLogitsWarper(LogitsProcessor):
    def __init__(
        self,
        top_p: float,
        filter_value: float = float("-inf"),
        min_tokens_to_keep: int = 1
    ):
        """Nucleus sampling: keep tokens with cumulative probability <= top_p."""

class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
    def __init__(self, penalty: float):
        """Apply repetition penalty to previously generated tokens."""

class NoRepeatNGramLogitsProcessor(LogitsProcessor):
    def __init__(self, ngram_size: int):
        """Prevent repeating n-grams."""
```
331
332
### Stopping Criteria
333
334
Flexible stopping conditions for generation.
335
336
```python { .api }
class StoppingCriteria:
    """Base class for stopping criteria."""

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
        **kwargs
    ) -> bool:
        """Check if generation should stop."""

class StoppingCriteriaList(List[StoppingCriteria]):
    """List of stopping criteria (OR logic)."""

class MaxLengthCriteria(StoppingCriteria):
    def __init__(self, max_length: int):
        """Stop when reaching maximum length."""

class MaxTimeCriteria(StoppingCriteria):
    def __init__(self, max_time: float):
        """Stop when exceeding maximum time."""

class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(
        self,
        keywords: List[str],
        tokenizer: PreTrainedTokenizer
    ):
        """Stop when generating specific keywords."""
```
367
368
### Generation Output Types
369
370
Structured outputs from different generation methods.
371
372
```python { .api }
class GenerateOutput:
    """Base output type for generation."""
    sequences: torch.LongTensor
    scores: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None

class GenerateBeamOutput(GenerateOutput):
    """Output from beam search generation."""
    sequences_scores: Optional[torch.FloatTensor] = None
    beam_indices: Optional[torch.LongTensor] = None

class GenerateSampleOutput(GenerateOutput):
    """Output from sampling generation."""

class GenerateGreedyOutput(GenerateOutput):
    """Output from greedy generation."""
```

### Streaming Generation

Real-time streaming of generated text.

```python { .api }
class BaseStreamer:
    """Base class for generation streamers."""

    def put(self, value: torch.LongTensor) -> None:
        """Process new generated tokens."""

    def end(self) -> None:
        """Signal end of generation."""

class TextStreamer(BaseStreamer):
    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        skip_prompt: bool = False,
        skip_special_tokens: bool = False,
        **decode_kwargs
    ):
        """
        Stream generated text to stdout.

        Args:
            tokenizer: Tokenizer for decoding
            skip_prompt: Skip printing the input prompt
            skip_special_tokens: Skip special tokens in output
            **decode_kwargs: Arguments for tokenizer.decode()
        """

class TextIteratorStreamer(BaseStreamer):
    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        skip_prompt: bool = False,
        timeout: Optional[float] = None,
        **decode_kwargs
    ):
        """
        Stream generated text through iterator interface.

        Args:
            tokenizer: Tokenizer for decoding
            skip_prompt: Skip the input prompt
            timeout: Timeout for iteration
            **decode_kwargs: Arguments for tokenizer.decode()
        """

    def __iter__(self) -> Iterator[str]:
        """Iterate over generated text chunks."""
```

## Generation Examples

Common generation patterns and use cases:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Basic generation
prompt = "The future of artificial intelligence is"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=50)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Sampling with temperature
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    do_sample=True,
    temperature=0.8,
    top_k=50,
    top_p=0.9
)

# Beam search
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    num_beams=5,
    early_stopping=True
)

# Multiple sequences
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    num_return_sequences=3,
    do_sample=True,
    temperature=0.8
)

# With custom generation config
gen_config = GenerationConfig(
    max_new_tokens=100,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.1,
    no_repeat_ngram_size=2
)

outputs = model.generate(**inputs, generation_config=gen_config)

# Streaming generation
from transformers import TextStreamer
streamer = TextStreamer(tokenizer, skip_prompt=True)

outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    streamer=streamer
)

# Constrained generation
from transformers import KeywordsStoppingCriteria
stop_words = ["END", "STOP"]
stopping_criteria = KeywordsStoppingCriteria(stop_words, tokenizer)

outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    stopping_criteria=[stopping_criteria]
)
```