# Generation

Advanced text generation capabilities with multiple decoding strategies, fine-grained control over output, and support for conversational AI. The generation system provides flexible interfaces for autoregressive text generation with extensive customization options.

## Capabilities

### Generation Mixin

Core generation functionality available on all generative models.

```python { .api }
class GenerationMixin:
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        use_model_defaults: Optional[bool] = None,
        custom_generate: Optional[Union[str, Callable]] = None,
        **kwargs
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """
        Generate sequences using the model.

        Args:
            inputs: Input token IDs
            generation_config: Generation configuration
            logits_processor: Custom logits processors
            stopping_criteria: Custom stopping criteria
            prefix_allowed_tokens_fn: Constrain generation to allowed tokens
            synced_gpus: Synchronize GPUs in distributed setting
            assistant_model: Assistant model for speculative decoding
            streamer: Streamer for real-time generation output
            negative_prompt_ids: Negative prompt for guidance
            negative_prompt_attention_mask: Attention mask for negative prompt
            use_model_defaults: Use model's default generation config
            custom_generate: Custom generation function or string identifier
            **kwargs: Additional generation parameters

        Returns:
            Generated token sequences
        """

    def beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        **kwargs
    ) -> Union[GenerateBeamOutput, torch.LongTensor]:
        """Beam search decoding."""

    def beam_sample(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        **kwargs
    ) -> Union[GenerateBeamOutput, torch.LongTensor]:
        """Beam search with sampling."""

    def group_beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        **kwargs
    ) -> Union[GenerateBeamOutput, torch.LongTensor]:
        """Diverse beam search with groups."""

    def sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        **kwargs
    ) -> Union[GenerateSampleOutput, torch.LongTensor]:
        """Sampling-based generation."""

    def greedy_search(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        **kwargs
    ) -> Union[GenerateGreedyOutput, torch.LongTensor]:
        """Greedy decoding."""

    def contrastive_search(
        self,
        input_ids: torch.LongTensor,
        penalty_alpha: float,
        top_k: int,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        **kwargs
    ) -> Union[GenerateContrastiveOutput, torch.LongTensor]:
        """Contrastive search decoding."""
```

### Generation Configuration

Comprehensive configuration for generation parameters and strategies.

```python { .api }
class GenerationConfig:
    def __init__(
        self,
        # Length parameters
        max_length: int = 20,
        max_new_tokens: Optional[int] = None,
        min_length: int = 0,
        min_new_tokens: Optional[int] = None,
        early_stopping: Union[bool, str] = False,
        max_time: Optional[float] = None,

        # Generation strategy
        do_sample: bool = False,
        num_beams: int = 1,
        num_beam_groups: int = 1,
        penalty_alpha: Optional[float] = None,
        use_cache: bool = True,

        # Sampling parameters
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 1.0,
        typical_p: float = 1.0,
        epsilon_cutoff: float = 0.0,
        eta_cutoff: float = 0.0,
        diversity_penalty: float = 0.0,

        # Repetition parameters
        repetition_penalty: float = 1.0,
        no_repeat_ngram_size: int = 0,
        encoder_no_repeat_ngram_size: int = 0,

        # Special tokens
        bos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        decoder_start_token_id: Optional[int] = None,

        # Generation control
        num_return_sequences: int = 1,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        output_scores: bool = False,
        return_dict_in_generate: bool = False,
        forced_bos_token_id: Optional[int] = None,
        forced_eos_token_id: Optional[Union[int, List[int]]] = None,
        remove_invalid_values: bool = False,
        exponential_decay_length_penalty: Optional[Tuple[int, float]] = None,
        suppress_tokens: Optional[List[int]] = None,
        begin_suppress_tokens: Optional[List[int]] = None,
        forced_decoder_ids: Optional[List[List[int]]] = None,

        # Sequence bias
        sequence_bias: Optional[Dict[Tuple[int], float]] = None,
        guidance_scale: Optional[float] = None,
        low_memory: Optional[bool] = None,

        # Watermarking
        watermarking_config: Optional[Dict] = None,

        **kwargs
    ):
        """
        Configuration for text generation.

        Key parameters:
            max_length: Maximum total sequence length
            max_new_tokens: Maximum number of new tokens to generate
            min_length: Minimum sequence length
            do_sample: Use sampling instead of greedy/beam search
            num_beams: Number of beams for beam search
            temperature: Sampling temperature (higher = more random)
            top_k: Keep only top-k tokens for sampling
            top_p: Nucleus sampling probability threshold
            repetition_penalty: Penalty for repeated tokens
            no_repeat_ngram_size: Prevent repeating n-grams
            num_return_sequences: Number of sequences to generate
        """

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name: str,
        config_file_name: Optional[str] = None,
        cache_dir: Optional[str] = None,
        force_download: bool = False,
        **kwargs
    ) -> "GenerationConfig":
        """Load generation config from pretrained model."""

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        config_file_name: Optional[str] = None,
        push_to_hub: bool = False,
        **kwargs
    ) -> None:
        """Save generation config to directory."""

    def update(self, **kwargs) -> None:
        """Update configuration with new parameters."""
```

### Beam Search Scoring

Advanced beam search with scoring and ranking capabilities.

```python { .api }
class BeamScorer:
    """Base class for beam search scoring."""

    def process(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        **kwargs
    ) -> Tuple[torch.Tensor]:
        """Process beam candidates."""

    def finalize(
        self,
        input_ids: torch.LongTensor,
        final_beam_scores: torch.FloatTensor,
        final_beam_tokens: torch.LongTensor,
        final_beam_indices: torch.LongTensor,
        **kwargs
    ) -> torch.LongTensor:
        """Finalize beam search."""

class BeamSearchScorer(BeamScorer):
    def __init__(
        self,
        batch_size: int,
        num_beams: int,
        device: torch.device,
        length_penalty: Optional[float] = 1.0,
        do_early_stopping: Optional[bool] = False,
        num_beam_hyps_to_keep: Optional[int] = 1,
        num_beam_groups: Optional[int] = 1,
        **kwargs
    ):
        """
        Beam search scorer with length penalty and early stopping.

        Args:
            batch_size: Batch size
            num_beams: Number of beams
            device: Device to run on
            length_penalty: Length penalty for beam scoring
            do_early_stopping: Stop when finding complete sequences
            num_beam_hyps_to_keep: Number of hypotheses to keep
            num_beam_groups: Number of beam groups for diverse search
        """

class ConstrainedBeamSearchScorer(BeamScorer):
    def __init__(
        self,
        batch_size: int,
        num_beams: int,
        device: torch.device,
        constraints: List[Constraint],
        **kwargs
    ):
        """Beam search with lexical constraints."""
```

### Logits Processing

Customizable logits processing for generation control.

```python { .api }
class LogitsProcessor:
    """Base class for logits processors."""

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Process logits before sampling/selection."""

class LogitsProcessorList(List[LogitsProcessor]):
    """List of logits processors applied sequentially."""

class TemperatureLogitsWarper(LogitsProcessor):
    def __init__(self, temperature: float):
        """Apply temperature scaling to logits."""

class TopKLogitsWarper(LogitsProcessor):
    def __init__(
        self,
        top_k: int,
        filter_value: float = float("-inf"),
        min_tokens_to_keep: int = 1
    ):
        """Keep only top-k tokens, set others to filter_value."""

class TopPLogitsWarper(LogitsProcessor):
    def __init__(
        self,
        top_p: float,
        filter_value: float = float("-inf"),
        min_tokens_to_keep: int = 1
    ):
        """Nucleus sampling: keep tokens with cumulative probability <= top_p."""

class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
    def __init__(self, penalty: float):
        """Apply repetition penalty to previously generated tokens."""

class NoRepeatNGramLogitsProcessor(LogitsProcessor):
    def __init__(self, ngram_size: int):
        """Prevent repeating n-grams."""
```

### Stopping Criteria

Flexible stopping conditions for generation.

```python { .api }
class StoppingCriteria:
    """Base class for stopping criteria."""

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
        **kwargs
    ) -> bool:
        """Check if generation should stop."""

class StoppingCriteriaList(List[StoppingCriteria]):
    """List of stopping criteria (OR logic)."""

class MaxLengthCriteria(StoppingCriteria):
    def __init__(self, max_length: int):
        """Stop when reaching maximum length."""

class MaxTimeCriteria(StoppingCriteria):
    def __init__(self, max_time: float):
        """Stop when exceeding maximum time."""

class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(
        self,
        keywords: List[str],
        tokenizer: PreTrainedTokenizer
    ):
        """Stop when generating specific keywords."""
```

### Generation Output Types

Structured outputs from different generation methods.

```python { .api }
class GenerateOutput:
    """Base output type for generation."""
    sequences: torch.LongTensor
    scores: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None

class GenerateBeamOutput(GenerateOutput):
    """Output from beam search generation."""
    sequences_scores: Optional[torch.FloatTensor] = None
    beam_indices: Optional[torch.LongTensor] = None

class GenerateSampleOutput(GenerateOutput):
    """Output from sampling generation."""

class GenerateGreedyOutput(GenerateOutput):
    """Output from greedy generation."""
```

### Streaming Generation

Real-time streaming of generated text.

```python { .api }
class BaseStreamer:
    """Base class for generation streamers."""

    def put(self, value: torch.LongTensor) -> None:
        """Process new generated tokens."""

    def end(self) -> None:
        """Signal end of generation."""

class TextStreamer(BaseStreamer):
    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        skip_prompt: bool = False,
        skip_special_tokens: bool = False,
        **decode_kwargs
    ):
        """
        Stream generated text to stdout.

        Args:
            tokenizer: Tokenizer for decoding
            skip_prompt: Skip printing the input prompt
            skip_special_tokens: Skip special tokens in output
            **decode_kwargs: Arguments for tokenizer.decode()
        """

class TextIteratorStreamer(BaseStreamer):
    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        skip_prompt: bool = False,
        timeout: Optional[float] = None,
        **decode_kwargs
    ):
        """
        Stream generated text through iterator interface.

        Args:
            tokenizer: Tokenizer for decoding
            skip_prompt: Skip the input prompt
            timeout: Timeout for iteration
            **decode_kwargs: Arguments for tokenizer.decode()
        """

    def __iter__(self) -> Iterator[str]:
        """Iterate over generated text chunks."""
```

## Generation Examples

Common generation patterns and use cases:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Basic generation
prompt = "The future of artificial intelligence is"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=50)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Sampling with temperature
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    do_sample=True,
    temperature=0.8,
    top_k=50,
    top_p=0.9
)

# Beam search
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    num_beams=5,
    early_stopping=True
)

# Multiple sequences
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    num_return_sequences=3,
    do_sample=True,
    temperature=0.8
)

# With custom generation config
gen_config = GenerationConfig(
    max_new_tokens=100,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.1,
    no_repeat_ngram_size=2
)

outputs = model.generate(**inputs, generation_config=gen_config)

# Streaming generation
from transformers import TextStreamer
streamer = TextStreamer(tokenizer, skip_prompt=True)

outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    streamer=streamer
)

# Constrained generation
from transformers import KeywordsStoppingCriteria
stop_words = ["END", "STOP"]
stopping_criteria = KeywordsStoppingCriteria(stop_words, tokenizer)

outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    stopping_criteria=[stopping_criteria]
)
```