CTranslate2 — fast inference engine for Transformer models.

Domain-specific model classes for speech recognition and audio processing tasks. CTranslate2 provides optimized implementations for Whisper (speech-to-text), Wav2Vec2 (speech representation learning), and Wav2Vec2Bert (enhanced speech processing) with the same performance optimizations as the core inference classes.

The Whisper class provides optimized inference for OpenAI's Whisper automatic speech recognition models, supporting transcription, translation, and language detection.
class Whisper:
def __init__(self, model_path: str, device: str = "auto",
device_index: int = 0, compute_type: str = "default",
inter_threads: int = 1, intra_threads: int = 0,
max_queued_batches: int = 0, files: dict = None):
"""
Initialize Whisper model for speech recognition.
Args:
model_path (str): Path to the CTranslate2 Whisper model directory
device (str): Device to run on ("cpu", "cuda", "auto")
device_index (int): Device index for multi-GPU setups
compute_type (str): Computation precision ("default", "float32", "float16", "int8")
inter_threads (int): Number of inter-op threads
intra_threads (int): Number of intra-op threads (0 for auto)
max_queued_batches (int): Maximum number of batches in queue
files (dict): Additional model files mapping
"""
def transcribe(self, features: list, language: str = None,
task: str = "transcribe", beam_size: int = 5,
patience: float = 1.0, length_penalty: float = 1.0,
repetition_penalty: float = 1.0, no_repeat_ngram_size: int = 0,
temperature: float = 1.0, compression_ratio_threshold: float = 2.4,
log_prob_threshold: float = -1.0, no_speech_threshold: float = 0.6,
condition_on_previous_text: bool = True,
prompt_reset_on_temperature: float = 0.5,
initial_prompt: str = None, prefix: str = None,
suppress_blank: bool = True, suppress_tokens: list = None,
without_timestamps: bool = False, max_initial_timestamp: float = 1.0,
word_timestamps: bool = False, prepend_punctuations: str = "\"'"¿([{-",
append_punctuations: str = "\"'.。,,!!??::")]}、",
vad_filter: bool = False, vad_parameters: dict = None,
max_new_tokens: int = None, clip_timestamps: list = None,
hallucination_silence_threshold: float = None,
hotwords: str = None, language_detection_threshold: float = None,
language_detection_segments: int = 1, **kwargs) -> list:
"""
Transcribe audio features to text.
Args:
features (list): List of mel-spectrogram features
language (str): Language code (e.g., "en", "fr", "de")
task (str): Task type ("transcribe" or "translate")
beam_size (int): Beam search size
patience (float): Beam search patience
length_penalty (float): Length penalty for beam search
repetition_penalty (float): Repetition penalty
no_repeat_ngram_size (int): N-gram size to avoid repetition
temperature (float): Sampling temperature
compression_ratio_threshold (float): Threshold for compression ratio
log_prob_threshold (float): Log probability threshold
no_speech_threshold (float): No speech detection threshold
condition_on_previous_text (bool): Use previous text as context
prompt_reset_on_temperature (float): Reset prompt at temperature
initial_prompt (str): Initial prompt text
prefix (str): Prefix for generated text
suppress_blank (bool): Suppress blank tokens
suppress_tokens (list): List of tokens to suppress
without_timestamps (bool): Generate without timestamps
max_initial_timestamp (float): Maximum initial timestamp
word_timestamps (bool): Generate word-level timestamps
prepend_punctuations (str): Punctuations to prepend
append_punctuations (str): Punctuations to append
vad_filter (bool): Apply voice activity detection filter
vad_parameters (dict): VAD configuration parameters
max_new_tokens (int): Maximum new tokens to generate
clip_timestamps (list): Timestamp clipping range
hallucination_silence_threshold (float): Silence threshold for hallucination detection
hotwords (str): Hotwords for biased generation
language_detection_threshold (float): Threshold for language detection
language_detection_segments (int): Number of segments for language detection
Returns:
list: List of WhisperGenerationResult objects
"""
def detect_language(self, features: list, **kwargs) -> list:
"""
Detect language from audio features.
Args:
features (list): List of mel-spectrogram features
**kwargs: Additional detection parameters
Returns:
list: List of detected languages with probabilities
"""
def generate(self, features: list, prompts: list = None, **kwargs) -> list:
"""
Generate text from audio features with optional prompts.
Args:
features (list): List of mel-spectrogram features
prompts (list): List of text prompts
**kwargs: Additional generation parameters
Returns:
list: List of generation results
"""The Wav2Vec2 class provides inference for Facebook's Wav2Vec2 models for speech representation learning and feature extraction.
class Wav2Vec2:
    """Inference for Facebook's Wav2Vec2 models for speech representation
    learning and feature extraction."""

    def __init__(self, model_path: str, device: str = "auto",
                 device_index: int = 0, compute_type: str = "default",
                 inter_threads: int = 1, intra_threads: int = 0,
                 max_queued_batches: int = 0, files: dict = None):
        """
        Initialize Wav2Vec2 model for speech processing.

        Args:
            model_path (str): Path to the CTranslate2 Wav2Vec2 model directory
            device (str): Device to run on ("cpu", "cuda", "auto")
            device_index (int): Device index for multi-GPU setups
            compute_type (str): Computation precision ("default", "float32", "float16", "int8")
            inter_threads (int): Number of inter-op threads
            intra_threads (int): Number of intra-op threads (0 for auto)
            max_queued_batches (int): Maximum number of batches in queue
            files (dict): Additional model files mapping
        """

    def encode(self, features: list, normalize: bool = False,
               return_hidden: bool = False, **kwargs) -> list:
        """
        Encode audio features using Wav2Vec2.

        Args:
            features (list): List of raw audio waveforms or features
            normalize (bool): Whether to normalize output representations
            return_hidden (bool): Whether to return hidden states
            **kwargs: Additional encoding parameters

        Returns:
            list: List of encoded representations
        """

    def forward_batch(self, inputs: list, **kwargs) -> list:
        """
        Forward pass on a batch of audio inputs.

        Args:
            inputs (list): List of audio input sequences
            **kwargs: Additional forward pass parameters

        Returns:
            list: List of forward pass outputs
        """


# The Wav2Vec2Bert class provides inference for the enhanced Wav2Vec2-BERT
# models that combine speech representation learning with BERT-style
# pretraining.
class Wav2Vec2Bert:
    """Inference for enhanced Wav2Vec2-BERT models that combine speech
    representation learning with BERT-style pretraining."""

    def __init__(self, model_path: str, device: str = "auto",
                 device_index: int = 0, compute_type: str = "default",
                 inter_threads: int = 1, intra_threads: int = 0,
                 max_queued_batches: int = 0, files: dict = None):
        """
        Initialize Wav2Vec2Bert model for enhanced speech processing.

        Args:
            model_path (str): Path to the CTranslate2 Wav2Vec2Bert model directory
            device (str): Device to run on ("cpu", "cuda", "auto")
            device_index (int): Device index for multi-GPU setups
            compute_type (str): Computation precision ("default", "float32", "float16", "int8")
            inter_threads (int): Number of inter-op threads
            intra_threads (int): Number of intra-op threads (0 for auto)
            max_queued_batches (int): Maximum number of batches in queue
            files (dict): Additional model files mapping
        """

    def encode(self, features: list, normalize: bool = False,
               return_hidden: bool = False, **kwargs) -> list:
        """
        Encode audio features using Wav2Vec2Bert.

        Args:
            features (list): List of raw audio waveforms or features
            normalize (bool): Whether to normalize output representations
            return_hidden (bool): Whether to return hidden states
            **kwargs: Additional encoding parameters

        Returns:
            list: List of encoded representations
        """

    def forward_batch(self, inputs: list, **kwargs) -> list:
        """
        Forward pass on a batch of audio inputs.

        Args:
            inputs (list): List of audio input sequences
            **kwargs: Additional forward pass parameters

        Returns:
            list: List of forward pass outputs
        """


import ctranslate2
# Example: basic transcription and word-level timestamps with Whisper.
# Load Whisper model
whisper = ctranslate2.models.Whisper("path/to/whisper_ct2_model", device="cpu")

# Prepare audio features (mel-spectrograms)
# Features should be mel-spectrograms with shape (80, time_steps)
audio_features = [mel_spectrogram_1, mel_spectrogram_2]  # List of numpy arrays

# Transcribe audio
results = whisper.transcribe(audio_features, language="en", task="transcribe")
for result in results:
    print("Transcription:", result.sequences[0])
    if hasattr(result, 'timestamps') and result.timestamps:
        print("Timestamps:", result.timestamps)

# Transcribe with word-level timestamps
results = whisper.transcribe(
    audio_features,
    language="en",
    word_timestamps=True,
    without_timestamps=False,
)
for result in results:
    print("Text:", result.sequences[0])
    for word_info in result.word_timestamps:
        print(f"Word: {word_info['word']}, Start: {word_info['start']:.2f}s, End: {word_info['end']:.2f}s")

import ctranslate2
# Example: language identification from audio features.
whisper = ctranslate2.models.Whisper("path/to/whisper_ct2_model")

# Detect language from audio
language_results = whisper.detect_language(audio_features)
for result in language_results:
    detected_language = result.language
    confidence = result.language_probability
    print(f"Detected language: {detected_language} (confidence: {confidence:.3f})")

import ctranslate2
# Example: speech translation (source language -> English).
whisper = ctranslate2.models.Whisper("path/to/whisper_ct2_model")

# Translate foreign speech to English
results = whisper.transcribe(
    audio_features,
    task="translate",  # Translate to English
    language="fr",     # Source language is French
)
for result in results:
    print("English translation:", result.sequences[0])

import ctranslate2
import numpy as np

# Example: extracting speech representations with Wav2Vec2.
# Load Wav2Vec2 model
wav2vec2 = ctranslate2.models.Wav2Vec2("path/to/wav2vec2_ct2_model", device="cpu")

# Prepare raw audio waveforms
# Audio should be 16kHz mono waveforms
audio_waveforms = [waveform_1, waveform_2]  # List of numpy arrays

# Extract speech representations
representations = wav2vec2.encode(audio_waveforms, normalize=True)
for representation in representations:  # renamed: `repr` shadows the builtin
    print("Representation shape:", representation.shape)

# Use representations for downstream tasks like speaker recognition,
# emotion detection, or as features for other models

import ctranslate2
# Example: enhanced representations with Wav2Vec2-BERT.
# Load Wav2Vec2Bert model
wav2vec2bert = ctranslate2.models.Wav2Vec2Bert("path/to/wav2vec2bert_ct2_model")

# Extract enhanced representations
enhanced_representations = wav2vec2bert.encode(
    audio_waveforms,
    normalize=True,
    return_hidden=True,
)
for representation in enhanced_representations:  # renamed: `repr` shadows the builtin
    print("Enhanced representation shape:", representation.shape)

# These representations combine speech and language understanding

import ctranslate2
# Example: batched transcription on GPU.
whisper = ctranslate2.models.Whisper("path/to/whisper_ct2_model", device="cuda")

# Process multiple audio files efficiently
batch_features = [features_1, features_2, features_3, features_4]

# Batch transcription
batch_results = whisper.transcribe(
    batch_features,
    language="en",
    beam_size=5,
    temperature=0.0,  # Deterministic output
)
for i, result in enumerate(batch_results):
    print(f"Audio {i+1}: {result.sequences[0]}")

import ctranslate2
whisper = ctranslate2.models.Whisper("path/to/whisper_ct2_model")
# Advanced transcription with custom parameters
results = whisper.transcribe(
audio_features,
language="en",
task="transcribe",
beam_size=10,
temperature=0.2,
compression_ratio_threshold=2.4,
log_prob_threshold=-1.0,
no_speech_threshold=0.6,
condition_on_previous_text=True,
initial_prompt="This is a technical presentation about machine learning.",
suppress_tokens=[50256, 50257], # Suppress specific tokens
word_timestamps=True,
vad_filter=True,
vad_parameters={
"threshold": 0.5,
"min_speech_duration_ms": 250,
"max_speech_duration_s": 30
}
)class WhisperGenerationResult:
"""Result from Whisper transcription/translation."""
sequences: list[list[str]] # Generated text sequences
scores: list[float] # Generation scores
language: str # Detected/specified language
language_probability: float # Language detection confidence
timestamps: list[dict] # Segment-level timestamps
word_timestamps: list[dict] # Word-level timestamps (if requested)
avg_logprob: float # Average log probability
compression_ratio: float # Compression ratio metric
no_speech_prob: float # No speech probability
class WhisperGenerationResultAsync:
"""Async result wrapper for Whisper operations."""
def result(self) -> WhisperGenerationResult: ...
def is_done(self) -> bool: ...
# Whisper-specific configuration structures
class WhisperTimestamp:
"""Word or segment timestamp information."""
start: float # Start time in seconds
end: float # End time in seconds
word: str # Word text (for word timestamps)
probability: float # Confidence score
class WhisperSegment:
"""Transcription segment with metadata."""
text: str # Segment text
start: float # Start time in seconds
end: float # End time in seconds
tokens: list[int] # Token IDs
temperature: float # Generation temperature used
avg_logprob: float # Average log probability
compression_ratio: float # Compression ratio
no_speech_prob: float # No speech probability
# Wav2Vec2 result types
class Wav2Vec2Output:
"""Output from Wav2Vec2 encoding."""
representations: StorageView # Learned speech representations
hidden_states: list # Hidden states (if requested)
attention_weights: list # Attention weights (if available)
class Wav2Vec2BertOutput:
"""Output from Wav2Vec2Bert encoding."""
representations: StorageView # Enhanced speech representations
hidden_states: list # Hidden states from all layers
attention_weights: list # Attention weights
adapter_outputs: list # Adapter layer outputsInstall with Tessl CLI
npx tessl i tessl/pypi-ctranslate2