# Specialized Models

Domain-specific model classes for speech recognition and audio processing tasks. CTranslate2 provides optimized implementations for Whisper (speech-to-text), Wav2Vec2 (speech representation learning), and Wav2Vec2Bert (enhanced speech processing) with the same performance optimizations as the core inference classes.
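
These classes load models in the CTranslate2 format, so a source checkpoint must be converted first. A minimal conversion sketch, assuming the `transformers` package is installed and using `openai/whisper-tiny` purely as an example checkpoint:

```python
import ctranslate2

# Convert a Hugging Face checkpoint into a CTranslate2 model directory;
# the output directory is what the classes below take as `model_path`.
# Quantizing at conversion time is optional.
converter = ctranslate2.converters.TransformersConverter("openai/whisper-tiny")
converter.convert("whisper-tiny-ct2", quantization="int8")
```
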
## Capabilities

### Whisper Speech Recognition

The `Whisper` class provides optimized inference for OpenAI's Whisper automatic speech recognition models, supporting transcription, translation, and language detection.

```python { .api }
class Whisper:
    def __init__(self, model_path: str, device: str = "auto",
                 device_index: int = 0, compute_type: str = "default",
                 inter_threads: int = 1, intra_threads: int = 0,
                 max_queued_batches: int = 0, files: dict = None):
        """
        Initialize a Whisper model for speech recognition.

        Args:
            model_path (str): Path to the CTranslate2 Whisper model directory
            device (str): Device to run on ("cpu", "cuda", "auto")
            device_index (int): Device index for multi-GPU setups
            compute_type (str): Computation precision ("default", "float32", "float16", "int8")
            inter_threads (int): Number of inter-op threads
            intra_threads (int): Number of intra-op threads (0 for auto)
            max_queued_batches (int): Maximum number of batches in the queue
            files (dict): Mapping of additional model file names to file objects
        """

    def transcribe(self, features: list, language: str = None,
                   task: str = "transcribe", beam_size: int = 5,
                   patience: float = 1.0, length_penalty: float = 1.0,
                   repetition_penalty: float = 1.0, no_repeat_ngram_size: int = 0,
                   temperature: float = 1.0, compression_ratio_threshold: float = 2.4,
                   log_prob_threshold: float = -1.0, no_speech_threshold: float = 0.6,
                   condition_on_previous_text: bool = True,
                   prompt_reset_on_temperature: float = 0.5,
                   initial_prompt: str = None, prefix: str = None,
                   suppress_blank: bool = True, suppress_tokens: list = None,
                   without_timestamps: bool = False, max_initial_timestamp: float = 1.0,
                   word_timestamps: bool = False, prepend_punctuations: str = "\"'“¿([{-",
                   append_punctuations: str = "\"'.。,,!!??::”)]}、",
                   vad_filter: bool = False, vad_parameters: dict = None,
                   max_new_tokens: int = None, clip_timestamps: list = None,
                   hallucination_silence_threshold: float = None,
                   hotwords: str = None, language_detection_threshold: float = None,
                   language_detection_segments: int = 1, **kwargs) -> list:
        """
        Transcribe audio features to text.

        Args:
            features (list): List of mel-spectrogram features
            language (str): Language code (e.g., "en", "fr", "de"); detected if omitted
            task (str): Task type ("transcribe" or "translate")
            beam_size (int): Beam search size
            patience (float): Beam search patience factor
            length_penalty (float): Length penalty applied during beam search
            repetition_penalty (float): Penalty applied to repeated tokens
            no_repeat_ngram_size (int): Prevent repetition of n-grams of this size
            temperature (float): Sampling temperature
            compression_ratio_threshold (float): Treat a segment as failed if its gzip compression ratio exceeds this value
            log_prob_threshold (float): Treat a segment as failed if its average token log probability falls below this value
            no_speech_threshold (float): Consider a segment silent if the no-speech probability exceeds this value
            condition_on_previous_text (bool): Use previously generated text as context for the next window
            prompt_reset_on_temperature (float): Reset the prompt if a segment was sampled above this temperature
            initial_prompt (str): Initial prompt text for the first window
            prefix (str): Prefix forced at the start of the generated text
            suppress_blank (bool): Suppress blank outputs at the start of sampling
            suppress_tokens (list): List of token IDs to suppress
            without_timestamps (bool): Generate text without timestamp tokens
            max_initial_timestamp (float): Maximum allowed value for the first timestamp
            word_timestamps (bool): Generate word-level timestamps
            prepend_punctuations (str): Punctuation attached to the following word in word timestamps
            append_punctuations (str): Punctuation attached to the preceding word in word timestamps
            vad_filter (bool): Apply a voice activity detection filter before transcription
            vad_parameters (dict): VAD configuration parameters
            max_new_tokens (int): Maximum number of new tokens to generate per segment
            clip_timestamps (list): Start/end timestamps (in seconds) restricting which audio ranges are processed
            hallucination_silence_threshold (float): Skip silences longer than this many seconds when a possible hallucination is detected (requires word_timestamps)
            hotwords (str): Hotword phrases used to bias generation
            language_detection_threshold (float): Probability threshold for accepting a detected language
            language_detection_segments (int): Number of segments used for language detection

        Returns:
            list: List of WhisperGenerationResult objects
        """

    def detect_language(self, features: list, **kwargs) -> list:
        """
        Detect the spoken language from audio features.

        Args:
            features (list): List of mel-spectrogram features
            **kwargs: Additional detection parameters

        Returns:
            list: For each input, a list of (language, probability) pairs sorted by decreasing probability
        """

    def generate(self, features: list, prompts: list = None, **kwargs) -> list:
        """
        Generate text from audio features with optional decoder prompts.

        Args:
            features (list): List of mel-spectrogram features
            prompts (list): List of decoder prompt token sequences (e.g., Whisper special tokens)
            **kwargs: Additional generation parameters

        Returns:
            list: List of generation results
        """
```
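
Depending on the CTranslate2 version, `features` must be wrapped in a `ctranslate2.StorageView` rather than passed as plain arrays. A minimal feature-preparation sketch following the upstream CTranslate2 Whisper example, assuming `transformers` and `librosa` are installed and `audio.wav` is a placeholder audio file:

```python
import ctranslate2
import librosa
import transformers

# Compute log-mel features with the processor matching the checkpoint
# (whisper-tiny uses 80 mel bins; large-v3 expects 128).
processor = transformers.WhisperProcessor.from_pretrained("openai/whisper-tiny")
audio, _ = librosa.load("audio.wav", sr=16000, mono=True)
inputs = processor(audio, return_tensors="np", sampling_rate=16000)

# Wrap the (batch, n_mels, frames) array for the CTranslate2 model.
features = ctranslate2.StorageView.from_array(inputs.input_features)
```
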
### Wav2Vec2 Speech Representation

The `Wav2Vec2` class provides inference for Facebook's Wav2Vec2 models for speech representation learning and feature extraction.

```python { .api }
class Wav2Vec2:
    def __init__(self, model_path: str, device: str = "auto",
                 device_index: int = 0, compute_type: str = "default",
                 inter_threads: int = 1, intra_threads: int = 0,
                 max_queued_batches: int = 0, files: dict = None):
        """
        Initialize a Wav2Vec2 model for speech processing.

        Args:
            model_path (str): Path to the CTranslate2 Wav2Vec2 model directory
            device (str): Device to run on ("cpu", "cuda", "auto")
            device_index (int): Device index for multi-GPU setups
            compute_type (str): Computation precision ("default", "float32", "float16", "int8")
            inter_threads (int): Number of inter-op threads
            intra_threads (int): Number of intra-op threads (0 for auto)
            max_queued_batches (int): Maximum number of batches in the queue
            files (dict): Mapping of additional model file names to file objects
        """

    def encode(self, features: list, normalize: bool = False,
               return_hidden: bool = False, **kwargs) -> list:
        """
        Encode audio features with Wav2Vec2.

        Args:
            features (list): List of raw audio waveforms or precomputed features
            normalize (bool): Whether to normalize the output representations
            return_hidden (bool): Whether to also return hidden states
            **kwargs: Additional encoding parameters

        Returns:
            list: List of encoded representations
        """

    def forward_batch(self, inputs: list, **kwargs) -> list:
        """
        Run a forward pass on a batch of audio inputs.

        Args:
            inputs (list): List of audio input sequences
            **kwargs: Additional forward pass parameters

        Returns:
            list: List of forward pass outputs
        """
```
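
Wav2Vec2 checkpoints are trained on 16 kHz audio that is usually normalized per utterance. A minimal preprocessing sketch, assuming NumPy input; `transformers`' `Wav2Vec2Processor` performs the same zero-mean/unit-variance step and can be used instead:

```python
import numpy as np

def normalize_waveform(waveform: np.ndarray) -> np.ndarray:
    """Zero-mean, unit-variance normalization of a mono 16 kHz waveform."""
    waveform = np.asarray(waveform, dtype=np.float32)
    return (waveform - waveform.mean()) / np.sqrt(waveform.var() + 1e-7)
```
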
### Wav2Vec2Bert Enhanced Speech Processing

The `Wav2Vec2Bert` class provides inference for the enhanced Wav2Vec2-BERT models that combine speech representation learning with BERT-style pretraining.

```python { .api }
class Wav2Vec2Bert:
    def __init__(self, model_path: str, device: str = "auto",
                 device_index: int = 0, compute_type: str = "default",
                 inter_threads: int = 1, intra_threads: int = 0,
                 max_queued_batches: int = 0, files: dict = None):
        """
        Initialize a Wav2Vec2Bert model for enhanced speech processing.

        Args:
            model_path (str): Path to the CTranslate2 Wav2Vec2Bert model directory
            device (str): Device to run on ("cpu", "cuda", "auto")
            device_index (int): Device index for multi-GPU setups
            compute_type (str): Computation precision ("default", "float32", "float16", "int8")
            inter_threads (int): Number of inter-op threads
            intra_threads (int): Number of intra-op threads (0 for auto)
            max_queued_batches (int): Maximum number of batches in the queue
            files (dict): Mapping of additional model file names to file objects
        """

    def encode(self, features: list, normalize: bool = False,
               return_hidden: bool = False, **kwargs) -> list:
        """
        Encode audio features with Wav2Vec2Bert.

        Args:
            features (list): List of raw audio waveforms or precomputed features
            normalize (bool): Whether to normalize the output representations
            return_hidden (bool): Whether to also return hidden states
            **kwargs: Additional encoding parameters

        Returns:
            list: List of encoded representations
        """

    def forward_batch(self, inputs: list, **kwargs) -> list:
        """
        Run a forward pass on a batch of audio inputs.

        Args:
            inputs (list): List of audio input sequences
            **kwargs: Additional forward pass parameters

        Returns:
            list: List of forward pass outputs
        """
```

## Usage Examples

### Whisper Speech-to-Text

```python
import ctranslate2

# Load the converted Whisper model
whisper = ctranslate2.models.Whisper("path/to/whisper_ct2_model", device="cpu")

# Prepare audio features: each entry is a mel-spectrogram with shape
# (80, time_steps); mel_spectrogram_1 and mel_spectrogram_2 are placeholders.
audio_features = [mel_spectrogram_1, mel_spectrogram_2]  # list of numpy arrays

# Transcribe audio
results = whisper.transcribe(audio_features, language="en", task="transcribe")

for result in results:
    print("Transcription:", result.sequences[0])
    if getattr(result, "timestamps", None):
        print("Timestamps:", result.timestamps)

# Transcribe with word-level timestamps
results = whisper.transcribe(
    audio_features,
    language="en",
    word_timestamps=True,
    without_timestamps=False,
)

for result in results:
    print("Text:", result.sequences[0])
    for word_info in result.word_timestamps:
        print(f"Word: {word_info['word']}, "
              f"Start: {word_info['start']:.2f}s, End: {word_info['end']:.2f}s")
```

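The upstream CTranslate2 documentation also exposes a lower-level path through `generate()`, where the decoder prompt is built from Whisper special tokens. A sketch reusing the `processor` and `features` objects from the feature-preparation example above; note that the `sequences_ids` field can depend on the installed version:

```python
# Build the decoder prompt from Whisper special tokens.
prompt = processor.tokenizer.convert_tokens_to_ids(
    ["<|startoftranscript|>", "<|en|>", "<|transcribe|>", "<|notimestamps|>"]
)

results = whisper.generate(features, [prompt])

# sequences_ids holds generated token IDs; decode them back to text.
transcription = processor.decode(results[0].sequences_ids[0])
print("Transcription:", transcription)
```
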
### Language Detection

```python
import ctranslate2

whisper = ctranslate2.models.Whisper("path/to/whisper_ct2_model")

# Detect the language: for each input, detect_language returns
# (language, probability) pairs sorted by decreasing probability.
results = whisper.detect_language(audio_features)

for pairs in results:
    detected_language, confidence = pairs[0]
    print(f"Detected language: {detected_language} (confidence: {confidence:.3f})")
```

### Translation Task

```python
import ctranslate2

whisper = ctranslate2.models.Whisper("path/to/whisper_ct2_model")

# Translate foreign speech to English (Whisper only translates into English)
results = whisper.transcribe(
    audio_features,
    task="translate",  # translate to English
    language="fr",     # source language is French
)

for result in results:
    print("English translation:", result.sequences[0])
```

### Wav2Vec2 Feature Extraction

```python
import ctranslate2
import numpy as np

# Load the converted Wav2Vec2 model
wav2vec2 = ctranslate2.models.Wav2Vec2("path/to/wav2vec2_ct2_model", device="cpu")

# Prepare raw audio: 16 kHz mono waveforms as numpy arrays
# (waveform_1 and waveform_2 are placeholders).
audio_waveforms = [waveform_1, waveform_2]

# Extract speech representations
representations = wav2vec2.encode(audio_waveforms, normalize=True)

for rep in representations:
    print("Representation shape:", rep.shape)
    # Use the representations for downstream tasks such as speaker
    # recognition or emotion detection, or as features for other models.
```

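For checkpoints fine-tuned for ASR with a CTC head, the frame-level outputs are scores over a CTC vocabulary and can be decoded greedily. A minimal sketch; `id_to_token` is an illustrative mapping (e.g., taken from the checkpoint's tokenizer), not part of CTranslate2:

```python
import numpy as np

def greedy_ctc_decode(logits: np.ndarray, id_to_token: dict, blank_id: int = 0) -> str:
    """Collapse repeats and drop blanks from frame-wise argmax predictions."""
    ids = logits.argmax(axis=-1).tolist()
    output, previous = [], None
    for token_id in ids:
        if token_id != previous and token_id != blank_id:
            output.append(id_to_token[token_id])
        previous = token_id
    return "".join(output)
```
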
### Wav2Vec2Bert Processing

```python
import ctranslate2

# Load the converted Wav2Vec2Bert model
wav2vec2bert = ctranslate2.models.Wav2Vec2Bert("path/to/wav2vec2bert_ct2_model")

# Extract enhanced representations (audio_waveforms as in the previous example)
enhanced_representations = wav2vec2bert.encode(
    audio_waveforms,
    normalize=True,
    return_hidden=True,
)

for rep in enhanced_representations:
    print("Enhanced representation shape:", rep.shape)
    # These representations combine speech and language understanding.
```

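The representations are frame-level; utterance-level tasks usually pool them first. A short sketch (mean pooling plus cosine similarity, assuming the outputs convert to NumPy arrays):

```python
import numpy as np

# Mean-pool frame-level features into a single utterance embedding.
frames = np.asarray(enhanced_representations[0])  # shape (time, dim)
embedding = frames.mean(axis=0)

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity between two utterance embeddings."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))
```
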
### Batch Processing for Efficiency

```python
import ctranslate2

whisper = ctranslate2.models.Whisper("path/to/whisper_ct2_model", device="cuda")

# Process multiple audio files in one call
# (features_1 ... features_4 are placeholder mel-spectrograms)
batch_features = [features_1, features_2, features_3, features_4]

# Batch transcription
batch_results = whisper.transcribe(
    batch_features,
    language="en",
    beam_size=5,
    temperature=0.0,  # deterministic output
)

for i, result in enumerate(batch_results):
    print(f"Audio {i + 1}: {result.sequences[0]}")
```

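Whisper models consume fixed 30-second windows, so longer recordings are chunked before feature extraction. A simple sketch that splits at fixed boundaries (no overlap or silence-aware segmentation):

```python
import numpy as np

def chunk_audio(waveform: np.ndarray, sample_rate: int = 16000,
                window_seconds: int = 30) -> np.ndarray:
    """Split a 1-D waveform into fixed windows, zero-padding the last one."""
    window = sample_rate * window_seconds
    n_chunks = int(np.ceil(len(waveform) / window))
    padded = np.pad(waveform, (0, n_chunks * window - len(waveform)))
    return padded.reshape(n_chunks, window)
```
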
### Advanced Whisper Configuration

```python
import ctranslate2

whisper = ctranslate2.models.Whisper("path/to/whisper_ct2_model")

# Advanced transcription with custom parameters
results = whisper.transcribe(
    audio_features,
    language="en",
    task="transcribe",
    beam_size=10,
    temperature=0.2,
    compression_ratio_threshold=2.4,
    log_prob_threshold=-1.0,
    no_speech_threshold=0.6,
    condition_on_previous_text=True,
    initial_prompt="This is a technical presentation about machine learning.",
    suppress_tokens=[50256, 50257],  # suppress specific token IDs
    word_timestamps=True,
    vad_filter=True,
    vad_parameters={
        "threshold": 0.5,
        "min_speech_duration_ms": 250,
        "max_speech_duration_s": 30,
    },
)
```

## Types

```python { .api }
class WhisperGenerationResult:
    """Result of a Whisper transcription/translation."""
    sequences: list[list[str]]      # Generated text sequences
    sequences_ids: list[list[int]]  # Generated token IDs
    scores: list[float]             # Generation scores
    language: str                   # Detected/specified language
    language_probability: float     # Language detection confidence
    timestamps: list[dict]          # Segment-level timestamps
    word_timestamps: list[dict]     # Word-level timestamps (if requested)
    avg_logprob: float              # Average log probability
    compression_ratio: float        # Compression ratio metric
    no_speech_prob: float           # No-speech probability

class WhisperGenerationResultAsync:
    """Future-like wrapper around an asynchronous Whisper result."""
    def result(self) -> WhisperGenerationResult: ...
    def done(self) -> bool: ...

# Whisper-specific result structures
class WhisperTimestamp:
    """Word or segment timestamp information."""
    start: float        # Start time in seconds
    end: float          # End time in seconds
    word: str           # Word text (for word timestamps)
    probability: float  # Confidence score

class WhisperSegment:
    """Transcription segment with metadata."""
    text: str                 # Segment text
    start: float              # Start time in seconds
    end: float                # End time in seconds
    tokens: list[int]         # Token IDs
    temperature: float        # Generation temperature used
    avg_logprob: float        # Average log probability
    compression_ratio: float  # Compression ratio
    no_speech_prob: float     # No-speech probability

# Wav2Vec2 result types
class Wav2Vec2Output:
    """Output of Wav2Vec2 encoding."""
    representations: StorageView  # Learned speech representations
    hidden_states: list           # Hidden states (if requested)
    attention_weights: list       # Attention weights (if available)

class Wav2Vec2BertOutput:
    """Output of Wav2Vec2Bert encoding."""
    representations: StorageView  # Enhanced speech representations
    hidden_states: list           # Hidden states from all layers
    attention_weights: list       # Attention weights
    adapter_outputs: list         # Adapter layer outputs
```
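
Since the async wrapper behaves like a `concurrent.futures.Future`, results can be polled with `done()` and collected with `result()`. A sketch assuming the installed version's `generate()` accepts an `asynchronous` flag, reusing `whisper`, `features`, and `prompt` from the earlier examples:

```python
# Submit generation without blocking; each entry is a future-like object.
async_results = whisper.generate(features, [prompt], asynchronous=True)

for async_result in async_results:
    print("finished already?", async_result.done())
    result = async_result.result()  # blocks until this input is done
    print(result.sequences_ids[0])
```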