# Model Inference

Core inference functionality for running Transformer models with high performance. CTranslate2 provides three main inference classes: `Translator` for sequence-to-sequence models, `Generator` for language models, and `Encoder` for encoder-only models. All classes support batching, streaming, and asynchronous execution.

## Capabilities

### Translation (Seq2Seq Models)

The `Translator` class handles sequence-to-sequence models like T5, BART, and traditional Transformer models for tasks such as machine translation, summarization, and text-to-text generation.
```python { .api }
class Translator:
    def __init__(self, model_path: str, device: str = "auto",
                 device_index: int = 0, compute_type: str = "default",
                 inter_threads: int = 1, intra_threads: int = 0,
                 max_queued_batches: int = 0, flash_attention: bool = False,
                 tensor_parallel: bool = False, files: dict = None):
        """
        Initialize a Translator for sequence-to-sequence models.

        Args:
            model_path (str): Path to the CTranslate2 model directory
            device (str): Device to run on ("cpu", "cuda", "auto")
            device_index (int): Device index for multi-GPU setups
            compute_type (str): Computation precision ("default", "float32", "float16", "int8")
            inter_threads (int): Number of inter-op threads
            intra_threads (int): Number of intra-op threads (0 for auto)
            max_queued_batches (int): Maximum number of batches in the queue
            flash_attention (bool): Enable Flash Attention optimization
            tensor_parallel (bool): Enable tensor parallelism
            files (dict): Additional model files mapping
        """

    def translate_batch(self, source: list, target_prefix: list = None,
                        beam_size: int = 1, patience: float = 1.0,
                        length_penalty: float = 1.0, coverage_penalty: float = 0.0,
                        repetition_penalty: float = 1.0, no_repeat_ngram_size: int = 0,
                        prefix_bias_beta: float = 0.0, use_vmap: bool = False,
                        return_end_token: bool = False, max_input_length: int = 1024,
                        max_decoding_length: int = 256, min_decoding_length: int = 1,
                        sampling_topk: int = 1, sampling_topp: float = 1.0,
                        sampling_temperature: float = 1.0, return_scores: bool = False,
                        return_attention: bool = False, return_alternatives: bool = False,
                        min_alternative_expansion_prob: float = 0.0,
                        num_hypotheses: int = 1, **kwargs) -> list:
        """
        Translate a batch of source sequences.

        Args:
            source (list): List of source sequences (each sequence is a list of tokens)
            target_prefix (list, optional): List of target prefixes to condition generation
            beam_size (int): Beam search size for decoding
            patience (float): Beam search patience factor
            length_penalty (float): Length penalty for beam search
            coverage_penalty (float): Coverage penalty to avoid repetition
            repetition_penalty (float): Repetition penalty for generated tokens
            no_repeat_ngram_size (int): Size of n-grams that cannot be repeated
            prefix_bias_beta (float): Bias applied toward the target prefix
            use_vmap (bool): Whether to use the vocabulary mapping file, if available
            return_end_token (bool): Whether to include the end token in the results
            max_input_length (int): Maximum input sequence length
            max_decoding_length (int): Maximum length of decoded sequences
            min_decoding_length (int): Minimum length of decoded sequences
            sampling_topk (int): Top-k sampling parameter
            sampling_topp (float): Top-p (nucleus) sampling parameter
            sampling_temperature (float): Temperature for sampling
            return_scores (bool): Whether to return scores
            return_attention (bool): Whether to return attention weights
            return_alternatives (bool): Whether to return alternatives at the first unconstrained position
            min_alternative_expansion_prob (float): Minimum probability to expand an alternative
            num_hypotheses (int): Number of hypotheses to return per input

        Returns:
            list: List of TranslationResult objects
        """

    def score_batch(self, source: list, target: list, max_input_length: int = 1024,
                    use_vmap: bool = False, **kwargs) -> list:
        """
        Score a batch of source-target sequence pairs.

        Args:
            source (list): List of source sequences
            target (list): List of target sequences to score
            max_input_length (int): Maximum input sequence length
            use_vmap (bool): Whether to use the vocabulary mapping file, if available

        Returns:
            list: List of ScoringResult objects with scores
        """

    def translate_iterable(self, source, target_prefix=None, batch_size: int = 32,
                           batch_type: str = "examples", **kwargs):
        """
        Translate an iterable of source sequences with efficient batching.

        Args:
            source: Iterable of source sequences
            target_prefix: Iterable of target prefixes (optional)
            batch_size (int): Maximum batch size
            batch_type (str): Batching strategy ("examples" or "tokens")
            **kwargs: Additional arguments passed to translate_batch

        Yields:
            TranslationResult: Results for each input sequence
        """

    def score_iterable(self, source, target, batch_size: int = 32,
                       batch_type: str = "examples", **kwargs):
        """
        Score an iterable of source-target pairs with efficient batching.

        Args:
            source: Iterable of source sequences
            target: Iterable of target sequences
            batch_size (int): Maximum batch size
            batch_type (str): Batching strategy ("examples" or "tokens")
            **kwargs: Additional arguments passed to score_batch

        Yields:
            ScoringResult: Scoring results for each input pair
        """

    def generate_tokens(self, source: list, target_prefix: list = None, **kwargs):
        """
        Generate tokens step-by-step for a single input.

        Args:
            source (list): Source sequence as a list of tokens
            target_prefix (list, optional): Target prefix tokens
            **kwargs: Additional generation parameters

        Yields:
            GenerationStepResult: Each generated token with metadata
        """

    @property
    def model_is_loaded(self) -> bool:
        """Whether the model is loaded in memory."""

    @property
    def device(self) -> str:
        """Device name where the model is running."""

    @property
    def device_index(self) -> list:
        """List of device indices being used."""

    @property
    def num_translators(self) -> int:
        """Number of translator instances."""

    @property
    def num_queued_batches(self) -> int:
        """Current number of queued batches."""

    @property
    def compute_type(self) -> str:
        """Compute type being used for inference."""
```
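
Scoring has no example in the usage section below, so here is a minimal sketch. The model path and token sequences are illustrative; the result fields follow the `ScoringResult` type documented under Types:

```python
import ctranslate2

translator = ctranslate2.Translator("path/to/model", device="cpu")

# Score a reference translation against its source sentence.
source = [["Hello", "world", "!"]]
target = [["Bonjour", "le", "monde", "!"]]

results = translator.score_batch(source, target)
# Per the ScoringResult type: scores are log probabilities,
# tokens_count the number of scored tokens.
print(results[0].scores, results[0].tokens_count)
```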
### Text Generation (Language Models)

The `Generator` class handles decoder-only language models like GPT-2, Llama, and Mistral for text generation, completion, and scoring tasks.
```python { .api }
class Generator:
    def __init__(self, model_path: str, device: str = "auto",
                 device_index: int = 0, compute_type: str = "default",
                 inter_threads: int = 1, intra_threads: int = 0,
                 max_queued_batches: int = 0, flash_attention: bool = False,
                 tensor_parallel: bool = False, files: dict = None):
        """
        Initialize a Generator for language models.

        Args:
            model_path (str): Path to the CTranslate2 model directory
            device (str): Device to run on ("cpu", "cuda", "auto")
            device_index (int): Device index for multi-GPU setups
            compute_type (str): Computation precision ("default", "float32", "float16", "int8")
            inter_threads (int): Number of inter-op threads
            intra_threads (int): Number of intra-op threads (0 for auto)
            max_queued_batches (int): Maximum number of batches in the queue
            flash_attention (bool): Enable Flash Attention optimization
            tensor_parallel (bool): Enable tensor parallelism
            files (dict): Additional model files mapping
        """

    def generate_batch(self, start_tokens: list, max_length: int = 512,
                       min_length: int = 0, sampling_topk: int = 1,
                       sampling_topp: float = 1.0, sampling_temperature: float = 1.0,
                       repetition_penalty: float = 1.0, no_repeat_ngram_size: int = 0,
                       disable_unk: bool = False, suppress_sequences: list = None,
                       end_token: str = None, return_end_token: bool = False,
                       max_input_length: int = 1024, static_prompt: list = None,
                       cache_static_prompt: bool = True, include_prompt_in_result: bool = True,
                       return_scores: bool = False, **kwargs) -> list:
        """
        Generate sequences from a batch of start tokens.

        Args:
            start_tokens (list): List of start token sequences
            max_length (int): Maximum length of generated sequences
            min_length (int): Minimum length of generated sequences
            sampling_topk (int): Top-k sampling parameter
            sampling_topp (float): Top-p (nucleus) sampling parameter
            sampling_temperature (float): Temperature for sampling
            repetition_penalty (float): Repetition penalty for generated tokens
            no_repeat_ngram_size (int): Size of n-grams that cannot be repeated
            disable_unk (bool): Whether to disable unknown token generation
            suppress_sequences (list): List of token sequences to suppress
            end_token (str): Token that ends generation
            return_end_token (bool): Whether to include the end token in the result
            max_input_length (int): Maximum input sequence length
            static_prompt (list): Static prompt tokens shared across calls
            cache_static_prompt (bool): Whether to cache the static prompt's model state
            include_prompt_in_result (bool): Whether to include the prompt in the output
            return_scores (bool): Whether to return generation scores

        Returns:
            list: List of GenerationResult objects
        """

    def score_batch(self, tokens: list, max_input_length: int = 1024, **kwargs) -> list:
        """
        Score a batch of token sequences.

        Args:
            tokens (list): List of token sequences to score
            max_input_length (int): Maximum sequence length to consider

        Returns:
            list: List of ScoringResult objects with scores
        """

    def generate_iterable(self, start_tokens, batch_size: int = 32,
                          batch_type: str = "examples", **kwargs):
        """
        Generate from an iterable of start token sequences with efficient batching.

        Args:
            start_tokens: Iterable of start token sequences
            batch_size (int): Maximum batch size
            batch_type (str): Batching strategy ("examples" or "tokens")
            **kwargs: Additional arguments passed to generate_batch

        Yields:
            GenerationResult: Results for each input sequence
        """

    def score_iterable(self, tokens, batch_size: int = 32,
                       batch_type: str = "examples", **kwargs):
        """
        Score an iterable of token sequences with efficient batching.

        Args:
            tokens: Iterable of token sequences
            batch_size (int): Maximum batch size
            batch_type (str): Batching strategy ("examples" or "tokens")
            **kwargs: Additional arguments passed to score_batch

        Yields:
            ScoringResult: Scoring results for each input sequence
        """

    def generate_tokens(self, prompt: list, **kwargs):
        """
        Generate tokens step-by-step for a single prompt.

        Args:
            prompt (list): Prompt tokens as a list
            **kwargs: Additional generation parameters

        Yields:
            GenerationStepResult: Each generated token with metadata
        """

    def async_generate_tokens(self, prompt: list, **kwargs):
        """
        Generate tokens asynchronously step-by-step for a single prompt.

        Args:
            prompt (list): Prompt tokens as a list
            **kwargs: Additional generation parameters

        Returns:
            An async iterable that yields GenerationStepResult objects
        """
```
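
A minimal async streaming sketch, assuming an asyncio event loop and a converted model at `path/to/model` (both illustrative):

```python
import asyncio

import ctranslate2

generator = ctranslate2.Generator("path/to/model", device="cpu")

async def stream(prompt):
    # async_generate_tokens yields GenerationStepResult objects as they
    # are produced, so tokens can be printed before decoding finishes.
    async for step in generator.async_generate_tokens(prompt):
        print(step.token, end="", flush=True)
    print()

asyncio.run(stream(["The", "quick", "brown"]))
```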
### Encoding (Encoder-Only Models)

The `Encoder` class handles encoder-only models like BERT and RoBERTa for feature extraction and representation learning tasks.
```python { .api }
class Encoder:
    def __init__(self, model_path: str, device: str = "auto",
                 device_index: int = 0, compute_type: str = "default",
                 inter_threads: int = 1, intra_threads: int = 0,
                 max_queued_batches: int = 0, files: dict = None):
        """
        Initialize an Encoder for encoder-only models.

        Args:
            model_path (str): Path to the CTranslate2 model directory
            device (str): Device to run on ("cpu", "cuda", "auto")
            device_index (int): Device index for multi-GPU setups
            compute_type (str): Computation precision ("default", "float32", "float16", "int8")
            inter_threads (int): Number of inter-op threads
            intra_threads (int): Number of intra-op threads (0 for auto)
            max_queued_batches (int): Maximum number of batches in the queue
            files (dict): Additional model files mapping
        """

    def forward_batch(self, inputs: list, normalize: bool = False,
                      max_input_length: int = 1024, **kwargs) -> EncoderForwardOutput:
        """
        Run a forward pass on a batch of input sequences.

        Args:
            inputs (list): List of input token sequences
            normalize (bool): Whether to normalize output embeddings
            max_input_length (int): Maximum input sequence length

        Returns:
            EncoderForwardOutput: Batched encoder outputs
        """
```
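
The usage section below has no encoder example, so here is a minimal sketch. The model path and the WordPiece-style tokens are illustrative, and converting the `StorageView` output to NumPy assumes the model runs on CPU:

```python
import numpy as np

import ctranslate2

encoder = ctranslate2.Encoder("path/to/model", device="cpu")

# Tokens must already match the model's vocabulary (e.g. WordPiece for BERT).
inputs = [["[CLS]", "hello", "world", "[SEP]"]]
output = encoder.forward_batch(inputs)

# On CPU, a StorageView converts directly to a NumPy array.
hidden = np.array(output.last_hidden_state)
print(hidden.shape)  # (batch, time, hidden_size)
```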
## Usage Examples

### Basic Translation
```python
import ctranslate2

# Load a translation model
translator = ctranslate2.Translator("path/to/model", device="cpu")

# Translate a single sentence
source = [["Hello", "world", "!"]]
results = translator.translate_batch(source)
print(results[0].hypotheses[0])  # e.g. ['Bonjour', 'le', 'monde', '!']

# Translate with beam search; return_scores is needed to populate scores
results = translator.translate_batch(source, beam_size=4, num_hypotheses=2,
                                      return_scores=True)
for i, hypothesis in enumerate(results[0].hypotheses):
    score = results[0].scores[i]
    print(f"Hypothesis {i+1} (score: {score:.4f}): {' '.join(hypothesis)}")
```
### Text Generation

```python
import ctranslate2

# Load a language model
generator = ctranslate2.Generator("path/to/model", device="cpu")

# Generate text
prompt = [["The", "quick", "brown", "fox"]]
results = generator.generate_batch(prompt, max_length=50, sampling_temperature=0.8)
print(" ".join(results[0].sequences[0]))

# Step-by-step generation
for step in generator.generate_tokens(["The", "quick", "brown"]):
    print(f"Token: {step.token}, Log probability: {step.log_prob:.4f}")
    if step.is_last:
        break
```
### Streaming Processing

```python
import ctranslate2

translator = ctranslate2.Translator("path/to/model")

# Process a large dataset efficiently
source_sentences = [["sentence", "1"], ["sentence", "2"], ...]  # Large list

# Stream processing with batching
for result in translator.translate_iterable(source_sentences, batch_size=32):
    translated = " ".join(result.hypotheses[0])
    print(translated)
```
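
When input lengths vary widely, token-based batching usually keeps memory use steadier than example-based batching. A minimal variant of the loop above, using the `batch_type` parameter documented for `translate_iterable` (the 1024-token budget is illustrative):

```python
# batch_size counts tokens per batch when batch_type="tokens".
for result in translator.translate_iterable(
    source_sentences, batch_size=1024, batch_type="tokens"
):
    print(" ".join(result.hypotheses[0]))
```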
## Types

```python { .api }
class TranslationResult:
    """Result from translation operations."""
    hypotheses: list[list[str]]     # List of hypothesis token sequences
    scores: list[float]             # Scores for each hypothesis (if requested)
    attention: list                 # Attention weights (if requested)

class GenerationResult:
    """Result from generation operations."""
    sequences: list[list[str]]      # Generated token sequences
    scores: list[float]             # Generation scores (if requested)
    sequences_ids: list[list[int]]  # Token IDs for generated sequences

class ScoringResult:
    """Result from scoring operations."""
    scores: list[float]             # Log probabilities for each sequence
    tokens_count: list[int]         # Token counts for each sequence

class GenerationStepResult:
    """Result from step-by-step generation."""
    token: str                      # Generated token
    token_id: int                   # Token ID
    is_last: bool                   # Whether this is the last token
    log_prob: float                 # Log probability of the token

class EncoderForwardOutput:
    """Output from an encoder forward pass."""
    last_hidden_state: StorageView  # Final hidden states
    pooler_output: StorageView      # Pooled output (if available)

class AsyncTranslationResult:
    """Async result wrapper for translation."""
    def result(self) -> TranslationResult: ...
    def is_done(self) -> bool: ...

class AsyncGenerationResult:
    """Async result wrapper for generation."""
    def result(self) -> GenerationResult: ...
    def is_done(self) -> bool: ...

class AsyncScoringResult:
    """Async result wrapper for scoring."""
    def result(self) -> ScoringResult: ...
    def is_done(self) -> bool: ...
```
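
The async wrappers are returned when a batch call is submitted without blocking. In upstream CTranslate2 this is requested by passing `asynchronous=True`, which the batch methods above would accept through `**kwargs`; that flag is an assumption here, not documented in the signatures. A sketch under that assumption:

```python
import ctranslate2

translator = ctranslate2.Translator("path/to/model")

# asynchronous=True (upstream CTranslate2 convention, assumed accepted via
# **kwargs) makes translate_batch return AsyncTranslationResult objects.
async_results = translator.translate_batch([["Hello", "world", "!"]],
                                           asynchronous=True)

for async_result in async_results:
    print(async_result.is_done())            # Poll without blocking
    print(async_result.result().hypotheses)  # Block until this item completes
```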