or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

audio-models.md, evaluation-metrics.md, generative-models.md, image-models.md, index.md, layers-components.md, multimodal-models.md, text-generation-sampling.md, tokenizers.md, utilities-helpers.md

docs/text-generation-sampling.md

# Text Generation Sampling

Sampling strategies for controlling text generation behavior in language models. Keras Hub provides various sampling methods to balance between quality, diversity, and controllability in generated text.

## Capabilities

### Base Classes

Foundation classes for text generation sampling.

```python { .api }
class Sampler:
    """Base class for all samplers."""
    def __init__(self, **kwargs): ...

    def __call__(
        self,
        next_token_logits,
        prompt_tokens,
        generated_tokens,
        **kwargs
    ): ...

    def get_next_token(self, probabilities): ...
```

### Deterministic Sampling

Samplers that produce deterministic outputs given the same input.

```python { .api }
class GreedySampler(Sampler):
    """
    Greedy sampling always selects the token with highest probability.
    Produces deterministic but potentially repetitive outputs.
    """
    def __init__(self, **kwargs): ...

class BeamSampler(Sampler):
    """
    Beam search maintains multiple candidate sequences and selects
    the sequence with highest overall probability.
    """
    def __init__(
        self,
        num_beams: int = 5,
        return_all_beams: bool = False,
        **kwargs
    ): ...
```

### Stochastic Sampling

Samplers that introduce randomness for more diverse outputs.

```python { .api }
class RandomSampler(Sampler):
    """
    Random sampling selects tokens according to their probability distribution.
    Higher temperature increases randomness.
    """
    def __init__(
        self,
        temperature: float = 1.0,
        seed: int = None,
        **kwargs
    ): ...

class TopKSampler(Sampler):
    """
    Top-k sampling considers only the k most likely tokens at each step.
    Balances quality and diversity by filtering low-probability tokens.
    """
    def __init__(
        self,
        k: int = 50,
        temperature: float = 1.0,
        seed: int = None,
        **kwargs
    ): ...

class TopPSampler(Sampler):
    """
    Top-p (nucleus) sampling considers tokens whose cumulative probability
    is within the top p fraction. Adapts the number of considered tokens
    based on the probability distribution.
    """
    def __init__(
        self,
        p: float = 0.9,
        temperature: float = 1.0,
        seed: int = None,
        **kwargs
    ): ...
```

### Advanced Sampling

More sophisticated sampling strategies for improved generation quality.

```python { .api }
class ContrastiveSampler(Sampler):
    """
    Contrastive search balances high probability and low repetition
    by penalizing tokens that are too similar to previously generated tokens.
    """
    def __init__(
        self,
        k: int = 4,
        alpha: float = 0.6,
        **kwargs
    ): ...
```

### Sampler Utilities

Utilities for working with samplers programmatically.

```python { .api }
def serialize(sampler: Sampler) -> dict:
    """
    Serialize a sampler instance to a dictionary.

    Args:
        sampler: The sampler instance to serialize

    Returns:
        Dictionary representation of the sampler
    """
    ...

def deserialize(config: dict) -> Sampler:
    """
    Deserialize a sampler from a dictionary configuration.

    Args:
        config: Dictionary configuration of the sampler

    Returns:
        Sampler instance
    """
    ...

def get(identifier) -> Sampler:
    """
    Get a sampler by name or return existing sampler instance.

    Args:
        identifier: String name or sampler instance

    Returns:
        Sampler instance
    """
    ...
```

## Usage Examples

### Greedy Sampling for Deterministic Output

```python
import keras_hub

# Load model
model = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Create greedy sampler
sampler = keras_hub.samplers.GreedySampler()

# Generate text deterministically
prompt = "The future of artificial intelligence"
output = model.generate(prompt, max_length=50, sampler=sampler)
print("Greedy output:", output)
```

### Random Sampling with Temperature Control

```python
import keras_hub

# Load model
model = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Low temperature for more focused generation
low_temp_sampler = keras_hub.samplers.RandomSampler(temperature=0.3)
output_focused = model.generate(
    "The weather today is",
    max_length=30,
    sampler=low_temp_sampler
)

# High temperature for more creative generation
high_temp_sampler = keras_hub.samplers.RandomSampler(temperature=1.5)
output_creative = model.generate(
    "The weather today is",
    max_length=30,
    sampler=high_temp_sampler
)

print("Focused output:", output_focused)
print("Creative output:", output_creative)
```

### Top-k Sampling for Quality-Diversity Balance

```python
import keras_hub

# Load model
model = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Top-k sampling with different k values
small_k_sampler = keras_hub.samplers.TopKSampler(k=10, temperature=0.8)
large_k_sampler = keras_hub.samplers.TopKSampler(k=100, temperature=0.8)

prompt = "In the distant future"

# More conservative generation (smaller k)
output_conservative = model.generate(prompt, max_length=40, sampler=small_k_sampler)

# More diverse generation (larger k)
output_diverse = model.generate(prompt, max_length=40, sampler=large_k_sampler)

print("Conservative (k=10):", output_conservative)
print("Diverse (k=100):", output_diverse)
```

### Top-p (Nucleus) Sampling

```python
import keras_hub

# Load model
model = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Top-p sampling adapts to probability distribution
sampler = keras_hub.samplers.TopPSampler(p=0.9, temperature=0.8)

# Generate multiple outputs to see diversity
prompt = "Once upon a time"
for i in range(3):
    output = model.generate(prompt, max_length=25, sampler=sampler)
    print(f"Output {i+1}: {output}")
```

### Beam Search for Best Overall Sequence

```python
import keras_hub

# Load model
model = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Beam search with different beam sizes
beam_sampler = keras_hub.samplers.BeamSampler(
    num_beams=5,
    return_all_beams=False  # Return only best beam
)

prompt = "The most important discovery in science"
output = model.generate(prompt, max_length=35, sampler=beam_sampler)
print("Beam search output:", output)

# Return all beams to see alternatives
all_beams_sampler = keras_hub.samplers.BeamSampler(
    num_beams=3,
    return_all_beams=True
)

all_outputs = model.generate(prompt, max_length=25, sampler=all_beams_sampler)
for i, beam_output in enumerate(all_outputs):
    print(f"Beam {i+1}: {beam_output}")
```

### Contrastive Search for Reducing Repetition

```python
import keras_hub

# Load model
model = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Contrastive search balances probability and novelty
sampler = keras_hub.samplers.ContrastiveSampler(
    k=4,        # Number of top tokens to consider
    alpha=0.6   # Balance between probability and novelty
)

prompt = "Artificial intelligence will change the world by"
output = model.generate(prompt, max_length=50, sampler=sampler)
print("Contrastive search output:", output)
```

### Comparing Different Sampling Methods

```python
import keras_hub

# Load model
model = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Define different samplers
samplers = {
    "Greedy": keras_hub.samplers.GreedySampler(),
    "Random (T=0.8)": keras_hub.samplers.RandomSampler(temperature=0.8),
    "Top-k (k=50)": keras_hub.samplers.TopKSampler(k=50, temperature=0.8),
    "Top-p (p=0.9)": keras_hub.samplers.TopPSampler(p=0.9, temperature=0.8),
    "Contrastive": keras_hub.samplers.ContrastiveSampler(k=4, alpha=0.6)
}

prompt = "The key to happiness is"

# Generate with each sampler
for name, sampler in samplers.items():
    output = model.generate(prompt, max_length=30, sampler=sampler)
    print(f"{name}: {output}")
```

### Serializing and Deserializing Samplers

```python
import keras_hub

# Create a sampler
original_sampler = keras_hub.samplers.TopKSampler(k=40, temperature=0.7)

# Serialize to dictionary
config = keras_hub.samplers.serialize(original_sampler)
print("Serialized config:", config)

# Deserialize back to sampler
restored_sampler = keras_hub.samplers.deserialize(config)

# Use restored sampler
model = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")
output = model.generate("Hello world", max_length=20, sampler=restored_sampler)
print("Generated with restored sampler:", output)
```

### Getting Samplers by Name

```python
import keras_hub

# Get sampler by string identifier
greedy = keras_hub.samplers.get("greedy")
random = keras_hub.samplers.get("random")

# Get existing sampler instance (returns same instance)
top_k = keras_hub.samplers.TopKSampler(k=50)
same_sampler = keras_hub.samplers.get(top_k)

print("Greedy sampler:", type(greedy).__name__)
print("Random sampler:", type(random).__name__)
print("Same instance:", top_k is same_sampler)
```

### Custom Sampling with Manual Control

```python
import keras_hub
import numpy as np

# Load model and get logits manually
model = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Get next token logits for a prompt
prompt_tokens = model.preprocessor.tokenizer(["Hello world"])
logits = model.backbone(prompt_tokens)[:, -1, :]  # Last token logits

# Apply different samplers to the same logits
samplers = [
    keras_hub.samplers.GreedySampler(),
    keras_hub.samplers.TopKSampler(k=10),
    keras_hub.samplers.TopPSampler(p=0.8)
]

for sampler in samplers:
    # Sample next token
    next_token = sampler(logits, prompt_tokens, generated_tokens=None)
    print(f"{type(sampler).__name__}: token {next_token}")
```