# Text Generation Sampling

Sampling strategies for controlling text generation behavior in language models. Keras Hub provides a range of sampling methods that balance quality, diversity, and controllability in generated text.
## Capabilities

### Base Classes

Foundation classes for text generation sampling.

```python { .api }
class Sampler:
    """Base class for all samplers."""
    def __init__(self, **kwargs): ...

    def __call__(
        self,
        next_token_logits,
        prompt_tokens,
        generated_tokens,
        **kwargs
    ): ...

    def get_next_token(self, probabilities): ...
```
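A custom strategy can hook into this interface by overriding `get_next_token`. The sketch below is illustrative only: it follows the `get_next_token` hook shown above, treats `probabilities` as a 1-D array over the vocabulary, and always picks the second-most-likely token. A real implementation would subclass `keras_hub.samplers.Sampler`.

```python
import numpy as np

class SecondBestSampler:
    """Toy sampler that always picks the second-most-likely token."""

    def get_next_token(self, probabilities):
        # argsort is ascending: the last index is the argmax,
        # the one before it is the second-most-likely token.
        order = np.argsort(probabilities)
        return int(order[-2])

sampler = SecondBestSampler()
probs = np.array([0.1, 0.6, 0.25, 0.05])
print(sampler.get_next_token(probs))  # → 2
```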
### Deterministic Sampling

Samplers that produce deterministic outputs given the same input.

```python { .api }
class GreedySampler(Sampler):
    """
    Greedy sampling always selects the token with the highest probability.
    Produces deterministic but potentially repetitive outputs.
    """
    def __init__(self, **kwargs): ...

class BeamSampler(Sampler):
    """
    Beam search maintains multiple candidate sequences and selects
    the sequence with the highest overall probability.
    """
    def __init__(
        self,
        num_beams: int = 5,
        return_all_beams: bool = False,
        **kwargs
    ): ...
```
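To make the beam-search idea concrete, here is a self-contained toy version over a fixed next-token probability table. The table, vocabulary, and numbers are invented purely for illustration; the real `BeamSampler` operates on model logits, not a lookup table.

```python
import math

# Toy conditional probabilities: P(next_token | last_token).
probs = {
    "<s>": {"a": 0.6, "b": 0.4},
    "a": {"a": 0.1, "b": 0.9},
    "b": {"a": 0.5, "b": 0.5},
}

def beam_search(start, steps, num_beams):
    # Each beam is a (log_prob, sequence) pair; start from the start token.
    beams = [(0.0, [start])]
    for _ in range(steps):
        candidates = []
        for logp, seq in beams:
            # Expand every beam with every possible next token.
            for tok, p in probs[seq[-1]].items():
                candidates.append((logp + math.log(p), seq + [tok]))
        # Keep only the num_beams highest-scoring candidates.
        beams = sorted(candidates, key=lambda c: c[0], reverse=True)[:num_beams]
    return beams[0]  # best (log_prob, sequence)

best_logp, best_seq = beam_search("<s>", steps=2, num_beams=2)
print(best_seq)  # → ['<s>', 'a', 'b'], overall probability 0.6 * 0.9 = 0.54
```

Note that greedy decoding would find the same answer here; beam search matters when an early high-probability token leads to a low-probability continuation.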
### Stochastic Sampling

Samplers that introduce randomness for more diverse outputs.

```python { .api }
class RandomSampler(Sampler):
    """
    Random sampling selects tokens according to their probability distribution.
    Higher temperature increases randomness.
    """
    def __init__(
        self,
        temperature: float = 1.0,
        seed: int | None = None,
        **kwargs
    ): ...

class TopKSampler(Sampler):
    """
    Top-k sampling considers only the k most likely tokens at each step.
    Balances quality and diversity by filtering out low-probability tokens.
    """
    def __init__(
        self,
        k: int = 50,
        temperature: float = 1.0,
        seed: int | None = None,
        **kwargs
    ): ...

class TopPSampler(Sampler):
    """
    Top-p (nucleus) sampling considers the smallest set of tokens whose
    cumulative probability reaches p. The number of considered tokens
    adapts to the shape of the probability distribution.
    """
    def __init__(
        self,
        p: float = 0.9,
        temperature: float = 1.0,
        seed: int | None = None,
        **kwargs
    ): ...
```
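The filtering these samplers apply can be illustrated on a toy logits vector with plain NumPy. This is a sketch of the underlying math on invented numbers, not the library's implementation:

```python
import numpy as np

def softmax(logits, temperature=1.0):
    # Temperature < 1 sharpens the distribution; > 1 flattens it.
    z = logits / temperature
    z = z - z.max()  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.5, 0.1, -1.0])
probs = softmax(logits)

# Top-k: keep only the k most likely tokens, then renormalize.
k = 2
top_k = np.argsort(probs)[::-1][:k]
k_probs = np.zeros_like(probs)
k_probs[top_k] = probs[top_k]
k_probs /= k_probs.sum()

# Top-p: keep the smallest set of tokens whose cumulative mass reaches p.
p = 0.9
order = np.argsort(probs)[::-1]
cumulative = np.cumsum(probs[order])
cutoff = np.searchsorted(cumulative, p) + 1
nucleus = order[:cutoff]
p_probs = np.zeros_like(probs)
p_probs[nucleus] = probs[nucleus]
p_probs /= p_probs.sum()

print("top-k support:", sorted(top_k.tolist()))    # → [0, 1]
print("top-p support:", sorted(nucleus.tolist()))  # → [0, 1, 2, 3]
```

For this distribution, top-k with k=2 keeps a fixed two tokens, while top-p with p=0.9 keeps four; on a more peaked distribution top-p would keep fewer, which is the adaptivity the `TopPSampler` docstring describes.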
### Advanced Sampling

More sophisticated sampling strategies for improved generation quality.

```python { .api }
class ContrastiveSampler(Sampler):
    """
    Contrastive search balances high probability against low repetition
    by penalizing tokens that are too similar to previously generated tokens.
    """
    def __init__(
        self,
        k: int = 4,
        alpha: float = 0.6,
        **kwargs
    ): ...
```
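The selection rule can be sketched as follows: for each of the k candidate tokens, contrastive search trades off model confidence against a degeneration penalty, the candidate's maximum cosine similarity to the representations of the tokens generated so far. The vectors and probabilities below are invented purely to illustrate the trade-off:

```python
import numpy as np

def contrastive_score(prob, cand_vec, context_vecs, alpha):
    # Degeneration penalty: max cosine similarity to the context tokens.
    sims = [
        np.dot(cand_vec, c) / (np.linalg.norm(cand_vec) * np.linalg.norm(c))
        for c in context_vecs
    ]
    return (1 - alpha) * prob - alpha * max(sims)

# Two candidates: one likely but redundant, one less likely but novel.
context = [np.array([1.0, 0.0])]
redundant = (0.7, np.array([1.0, 0.05]))  # high prob, near-duplicate vector
novel = (0.3, np.array([0.0, 1.0]))       # lower prob, orthogonal vector

alpha = 0.6
scores = [contrastive_score(p, v, context, alpha) for p, v in (redundant, novel)]
best = int(np.argmax(scores))
print("chosen candidate:", ["redundant", "novel"][best])  # → novel
```

With alpha=0.6 the penalty dominates, so the novel candidate wins despite its lower probability; alpha=0 would reduce the rule to greedy selection among the k candidates.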
### Sampler Utilities

Utilities for working with samplers programmatically.

```python { .api }
def serialize(sampler: Sampler) -> dict:
    """
    Serialize a sampler instance to a dictionary.

    Args:
        sampler: The sampler instance to serialize.

    Returns:
        Dictionary representation of the sampler.
    """
    ...

def deserialize(config: dict) -> Sampler:
    """
    Deserialize a sampler from a dictionary configuration.

    Args:
        config: Dictionary configuration of the sampler.

    Returns:
        Sampler instance.
    """
    ...

def get(identifier) -> Sampler:
    """
    Get a sampler by string name, or return an existing sampler
    instance unchanged.

    Args:
        identifier: String name or sampler instance.

    Returns:
        Sampler instance.
    """
    ...
```
## Usage Examples

### Greedy Sampling for Deterministic Output

```python
import keras_hub

# Load the model
model = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Create a greedy sampler
sampler = keras_hub.samplers.GreedySampler()

# Generate text deterministically
prompt = "The future of artificial intelligence"
output = model.generate(prompt, max_length=50, sampler=sampler)
print("Greedy output:", output)
```
### Random Sampling with Temperature Control

```python
import keras_hub

# Load the model
model = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Low temperature for more focused generation
low_temp_sampler = keras_hub.samplers.RandomSampler(temperature=0.3)
output_focused = model.generate(
    "The weather today is",
    max_length=30,
    sampler=low_temp_sampler
)

# High temperature for more creative generation
high_temp_sampler = keras_hub.samplers.RandomSampler(temperature=1.5)
output_creative = model.generate(
    "The weather today is",
    max_length=30,
    sampler=high_temp_sampler
)

print("Focused output:", output_focused)
print("Creative output:", output_creative)
```
### Top-k Sampling for Quality-Diversity Balance

```python
import keras_hub

# Load the model
model = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Top-k sampling with different k values
small_k_sampler = keras_hub.samplers.TopKSampler(k=10, temperature=0.8)
large_k_sampler = keras_hub.samplers.TopKSampler(k=100, temperature=0.8)

prompt = "In the distant future"

# More conservative generation (smaller k)
output_conservative = model.generate(prompt, max_length=40, sampler=small_k_sampler)

# More diverse generation (larger k)
output_diverse = model.generate(prompt, max_length=40, sampler=large_k_sampler)

print("Conservative (k=10):", output_conservative)
print("Diverse (k=100):", output_diverse)
```
### Top-p (Nucleus) Sampling

```python
import keras_hub

# Load the model
model = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Top-p sampling adapts to the shape of the probability distribution
sampler = keras_hub.samplers.TopPSampler(p=0.9, temperature=0.8)

# Generate multiple outputs to see the diversity
prompt = "Once upon a time"
for i in range(3):
    output = model.generate(prompt, max_length=25, sampler=sampler)
    print(f"Output {i + 1}: {output}")
```
### Beam Search for Best Overall Sequence

```python
import keras_hub

# Load the model
model = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Beam search returning only the best sequence
beam_sampler = keras_hub.samplers.BeamSampler(
    num_beams=5,
    return_all_beams=False  # Return only the best beam
)

prompt = "The most important discovery in science"
output = model.generate(prompt, max_length=35, sampler=beam_sampler)
print("Beam search output:", output)

# Return all beams to see the alternatives
all_beams_sampler = keras_hub.samplers.BeamSampler(
    num_beams=3,
    return_all_beams=True
)

all_outputs = model.generate(prompt, max_length=25, sampler=all_beams_sampler)
for i, beam_output in enumerate(all_outputs):
    print(f"Beam {i + 1}: {beam_output}")
```
### Contrastive Search for Reducing Repetition

```python
import keras_hub

# Load the model
model = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Contrastive search balances probability and novelty
sampler = keras_hub.samplers.ContrastiveSampler(
    k=4,        # Number of top tokens to consider
    alpha=0.6   # Balance between probability and novelty
)

prompt = "Artificial intelligence will change the world by"
output = model.generate(prompt, max_length=50, sampler=sampler)
print("Contrastive search output:", output)
```
### Comparing Different Sampling Methods

```python
import keras_hub

# Load the model
model = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Define different samplers
samplers = {
    "Greedy": keras_hub.samplers.GreedySampler(),
    "Random (T=0.8)": keras_hub.samplers.RandomSampler(temperature=0.8),
    "Top-k (k=50)": keras_hub.samplers.TopKSampler(k=50, temperature=0.8),
    "Top-p (p=0.9)": keras_hub.samplers.TopPSampler(p=0.9, temperature=0.8),
    "Contrastive": keras_hub.samplers.ContrastiveSampler(k=4, alpha=0.6)
}

prompt = "The key to happiness is"

# Generate with each sampler
for name, sampler in samplers.items():
    output = model.generate(prompt, max_length=30, sampler=sampler)
    print(f"{name}: {output}")
```
### Serializing and Deserializing Samplers

```python
import keras_hub

# Create a sampler
original_sampler = keras_hub.samplers.TopKSampler(k=40, temperature=0.7)

# Serialize to a dictionary
config = keras_hub.samplers.serialize(original_sampler)
print("Serialized config:", config)

# Deserialize back to a sampler
restored_sampler = keras_hub.samplers.deserialize(config)

# Use the restored sampler
model = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")
output = model.generate("Hello world", max_length=20, sampler=restored_sampler)
print("Generated with restored sampler:", output)
```
### Getting Samplers by Name

```python
import keras_hub

# Get samplers by string identifier
greedy_sampler = keras_hub.samplers.get("greedy")
random_sampler = keras_hub.samplers.get("random")

# Passing an existing sampler instance returns that same instance
top_k = keras_hub.samplers.TopKSampler(k=50)
same_sampler = keras_hub.samplers.get(top_k)

print("Greedy sampler:", type(greedy_sampler).__name__)
print("Random sampler:", type(random_sampler).__name__)
print("Same instance:", top_k is same_sampler)
```
### Custom Sampling with Manual Control

```python
import keras_hub

# Load the model to compute next-token logits manually
model = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Preprocess the prompt; calling the causal LM on the preprocessed
# inputs yields per-position vocabulary logits (the backbone alone
# would return hidden states, not logits)
inputs = model.preprocessor.generate_preprocess(["Hello world"])
logits = model(inputs)[:, -1, :]  # logits for the next token

# Apply different samplers to the same logits
samplers = [
    keras_hub.samplers.GreedySampler(),
    keras_hub.samplers.TopKSampler(k=10),
    keras_hub.samplers.TopPSampler(p=0.8)
]

for sampler in samplers:
    # Sample the next token
    next_token = sampler(logits, inputs["token_ids"], generated_tokens=None)
    print(f"{type(sampler).__name__}: token {next_token}")
```