# Pre-trained Models

Ready-to-use neural network models for speech recognition, synthesis, and source separation. TorchAudio provides implementations of state-of-the-art architectures along with factory functions for instantiating them.

## Capabilities

### Speech Recognition Models

Neural networks for automatic speech recognition and speech representation learning.

```python { .api }
class Wav2Vec2Model(torch.nn.Module):
    """Wav2Vec2 model for speech representation learning."""

    def __init__(self, feature_extractor: torch.nn.Module, encoder: torch.nn.Module,
                 aux: Optional[torch.nn.Module] = None) -> None:
        """
        Args:
            feature_extractor: CNN feature extractor
            encoder: Transformer encoder
            aux: Auxiliary output layer (for fine-tuned models)
        """

    def forward(self, waveforms: torch.Tensor,
                lengths: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            waveforms: Input audio (batch, time)
            lengths: Length of each sequence in batch

        Returns:
            Tuple of (features, lengths). When ``aux`` is present, it is applied
            to the features, so the first element holds the logits.
        """

def wav2vec2_model(arch: str, aux_num_out: Optional[int] = None) -> Wav2Vec2Model:
    """Create Wav2Vec2 model with the specified architecture."""

def wav2vec2_base(aux_num_out: Optional[int] = None) -> Wav2Vec2Model:
    """Create base Wav2Vec2 model (12 layers, 768 dim)."""

def wav2vec2_large(aux_num_out: Optional[int] = None) -> Wav2Vec2Model:
    """Create large Wav2Vec2 model (24 layers, 1024 dim)."""

def wav2vec2_large_lv60k(aux_num_out: Optional[int] = None) -> Wav2Vec2Model:
    """Create large Wav2Vec2 model with the LV-60k (Libri-Light) architecture."""

def wav2vec2_xlsr_300m(aux_num_out: Optional[int] = None) -> Wav2Vec2Model:
    """Create multilingual XLS-R model with 300 million parameters."""

def wav2vec2_xlsr_1b(aux_num_out: Optional[int] = None) -> Wav2Vec2Model:
    """Create multilingual XLS-R model with 1 billion parameters."""

def wav2vec2_xlsr_2b(aux_num_out: Optional[int] = None) -> Wav2Vec2Model:
    """Create multilingual XLS-R model with 2 billion parameters."""

class HuBERTPretrainModel(torch.nn.Module):
    """HuBERT model for self-supervised speech representation learning."""

    def __init__(self, feature_extractor: torch.nn.Module, encoder: torch.nn.Module,
                 final_proj: torch.nn.Module, label_embs_concat: torch.nn.Module,
                 mask_generator: torch.nn.Module, logit_temp: float) -> None:
        """
        Args:
            feature_extractor: CNN feature extractor
            encoder: Transformer encoder
            final_proj: Final projection layer
            label_embs_concat: Concatenated label embeddings
            mask_generator: Mask generator for pre-training
            logit_temp: Temperature applied to the logits
        """

    def forward(self, waveforms: torch.Tensor, labels: Optional[torch.Tensor] = None,
                audio_lengths: Optional[torch.Tensor] = None) -> HuBERTPretrainModelOutput:
        """
        Args:
            waveforms: Input audio (batch, time)
            labels: Target labels for pre-training
            audio_lengths: Length of each sequence

        Returns:
            HuBERTPretrainModelOutput with logits, features, etc.
        """

def hubert_base(aux_num_out: Optional[int] = None) -> Wav2Vec2Model:
    """Create base HuBERT model."""

def hubert_large(aux_num_out: Optional[int] = None) -> Wav2Vec2Model:
    """Create large HuBERT model."""

def hubert_xlarge(aux_num_out: Optional[int] = None) -> Wav2Vec2Model:
    """Create extra-large HuBERT model."""

def hubert_pretrain_model(arch: str, aux_num_out: Optional[int] = None) -> HuBERTPretrainModel:
    """Create HuBERT pre-training model with the specified architecture."""

def wavlm_model(arch: str, aux_num_out: Optional[int] = None) -> Wav2Vec2Model:
    """Create WavLM model with the specified architecture."""

def wavlm_base(aux_num_out: Optional[int] = None) -> Wav2Vec2Model:
    """Create base WavLM model."""

def wavlm_large(aux_num_out: Optional[int] = None) -> Wav2Vec2Model:
    """Create large WavLM model."""
```
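
As a quick illustration of the feature-extraction interface these models share, the sketch below runs a randomly initialized WavLM encoder on dummy audio. The one-second duration, batch shape, and 16 kHz rate are assumptions for illustration; pre-trained weights would normally come from a pipeline bundle rather than the bare factory function.

```python
import torch
from torchaudio.models import wavlm_base

# Build a WavLM encoder without an auxiliary head (weights are randomly
# initialized here; this only demonstrates the calling convention).
model = wavlm_base()
model.eval()

# Dummy batch: one 1-second waveform at an assumed 16 kHz sample rate.
waveform = torch.randn(1, 16000)

with torch.no_grad():
    features, lengths = model(waveform)  # features: (batch, frames, 768)
```
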
### Legacy Speech Recognition Models

Traditional neural network architectures for speech recognition.

```python { .api }
class DeepSpeech(torch.nn.Module):
    """DeepSpeech model for end-to-end speech recognition."""

    def __init__(self, n_hidden: int, n_class: int) -> None:
        """
        Args:
            n_hidden: Number of hidden units in the RNN layers
            n_class: Number of output classes (characters/phonemes)
        """

    def forward(self, x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input features (..., freq, time)
            lengths: Length of each sequence

        Returns:
            Tensor: Logits over character classes (..., time, n_class)
        """

class Wav2Letter(torch.nn.Module):
    """Wav2Letter model for speech recognition."""

    def __init__(self, num_classes: int, input_type: str = "waveform",
                 num_features: Optional[int] = None, num_hidden: int = 1000) -> None:
        """
        Args:
            num_classes: Number of output classes
            input_type: Type of input ("waveform", "power_spectrum", or "mfcc")
            num_features: Number of input features (required for non-waveform input types)
            num_hidden: Number of hidden units
        """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor (batch, num_features, time)

        Returns:
            Tensor: Log probabilities over classes (batch, num_classes, reduced_time)
        """
```
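
A minimal sketch of running Wav2Letter directly on raw audio follows; the 40-class output, batch shape, and 16 kHz duration are illustrative assumptions, not recommended settings.

```python
import torch
from torchaudio.models import Wav2Letter

# Build a waveform-input Wav2Letter with an assumed 40-class output.
model = Wav2Letter(num_classes=40, input_type="waveform")
model.eval()

# Dummy batch: (batch, channel, time), one second at an assumed 16 kHz.
x = torch.randn(1, 1, 16000)

with torch.no_grad():
    log_probs = model(x)  # (batch, num_classes, reduced_time)
```
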
### RNN-Transducer Models

Neural transducer models for streaming speech recognition.

```python { .api }
class RNNT(torch.nn.Module):
    """RNN-Transducer model for streaming speech recognition."""

    def __init__(self, transcriber: torch.nn.Module, predictor: torch.nn.Module,
                 joiner: torch.nn.Module) -> None:
        """
        Args:
            transcriber: Encoder network (processes audio features)
            predictor: Decoder network (processes previous predictions)
            joiner: Joint network (combines encoder and decoder outputs)
        """

    def forward(self, sources: torch.Tensor, source_lengths: torch.Tensor,
                targets: torch.Tensor, target_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            sources: Input audio features (batch, time, feature_dim)
            source_lengths: Length of each audio sequence
            targets: Target token sequences (batch, target_time)
            target_lengths: Length of each target sequence

        Returns:
            Tuple of (transcriber_out, predictor_out, joiner_out)
        """

class Conformer(torch.nn.Module):
    """Conformer model combining CNN and self-attention."""

    def __init__(self, input_dim: int, num_heads: int, ffn_dim: int, num_layers: int,
                 depthwise_conv_kernel_size: int = 31, dropout: float = 0.1,
                 use_group_norm: bool = False, convolution_first: bool = False) -> None:
        """
        Args:
            input_dim: Input feature dimension
            num_heads: Number of attention heads
            ffn_dim: Feed-forward network dimension
            num_layers: Number of conformer layers
            depthwise_conv_kernel_size: Kernel size for depthwise convolution
            dropout: Dropout probability
            use_group_norm: Whether to use group normalization
            convolution_first: Whether to apply convolution before self-attention
        """

    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            input: Input features (batch, time, feature_dim)
            lengths: Length of each sequence

        Returns:
            Tuple of (output, output_lengths)
        """

class Emformer(torch.nn.Module):
    """Emformer model for streaming applications."""

    def __init__(self, input_dim: int, num_heads: int, ffn_dim: int, num_layers: int,
                 segment_length: int, left_context_length: int = 0,
                 right_context_length: int = 0, max_memory_size: int = 0,
                 weight_init_scale_strategy: str = "depthwise", tanh_on_mem: bool = False,
                 negative_inf: float = -1e8) -> None:
        """
        Args:
            input_dim: Input feature dimension
            num_heads: Number of attention heads
            ffn_dim: Feed-forward dimension
            num_layers: Number of layers
            segment_length: Length of each segment
            left_context_length: Left context length
            right_context_length: Right context length
            max_memory_size: Maximum memory size
            weight_init_scale_strategy: Weight initialization strategy
            tanh_on_mem: Whether to apply tanh on memory
            negative_inf: Negative infinity value for masking
        """

    def forward(self, input: torch.Tensor, lengths: torch.Tensor,
                mems: Optional[List[List[torch.Tensor]]] = None) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        """
        Args:
            input: Input features (batch, time, feature_dim)
            lengths: Length of each sequence
            mems: Previous memory states

        Returns:
            Tuple of (output, output_lengths, new_mems)
        """

def emformer_rnnt_base(num_symbols: int) -> RNNT:
    """Create base Emformer RNN-T model."""

def emformer_rnnt_model(arch: str, num_symbols: int) -> RNNT:
    """Create Emformer RNN-T model with the specified architecture."""
```
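
To make the encoder interface concrete, here is a small sketch that runs a Conformer over a batch of dummy filterbank features; all hyperparameter values and sequence lengths are illustrative assumptions.

```python
import torch
from torchaudio.models import Conformer

# A small Conformer over 80-dimensional features (illustrative sizes only).
model = Conformer(input_dim=80, num_heads=4, ffn_dim=256, num_layers=4,
                  depthwise_conv_kernel_size=31)
model.eval()

# Dummy batch of two padded feature sequences with their true lengths.
features = torch.randn(2, 100, 80)
lengths = torch.tensor([100, 75])

with torch.no_grad():
    output, output_lengths = model(features, lengths)  # (2, 100, 80), (2,)
```
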
### Speech Synthesis Models

Neural networks for text-to-speech synthesis and vocoding.

```python { .api }
class Tacotron2(torch.nn.Module):
    """Tacotron2 model for text-to-speech synthesis."""

    def __init__(self, mask_padding: bool = False, n_mels: int = 80,
                 n_frames_per_step: int = 1, n_characters: int = 188,
                 n_hidden: int = 1024, p_attention_dropout: float = 0.1,
                 p_decoder_dropout: float = 0.1, prenet_dim: int = 256,
                 postnet_embedding_dim: int = 512, postnet_kernel_size: int = 5,
                 postnet_n_convolutions: int = 5, postnet_dropout: float = 0.5,
                 attention_rnn_dim: int = 1024, attention_dim: int = 128,
                 attention_location_n_filters: int = 32, attention_location_kernel_size: int = 31,
                 encoder_embedding_dim: int = 512, encoder_n_convolutions: int = 3,
                 encoder_kernel_size: int = 5, encoder_dropout: float = 0.5,
                 decoder_rnn_dim: int = 1024, decoder_max_step: int = 2000,
                 gate_threshold: float = 0.5, p_teacher_forcing: float = 1.0,
                 decoder_dropout: float = 0.1, memory_dropout: float = 0.1) -> None:
        """
        Args:
            mask_padding: Whether to mask padding in loss computation
            n_mels: Number of mel frequency bins
            n_frames_per_step: Number of frames generated per step
            (additional parameters configure the encoder, attention, decoder, and postnet)
        """

    def forward(self, tokens: torch.Tensor, token_lengths: torch.Tensor,
                mel_specgram: Optional[torch.Tensor] = None,
                mel_specgram_lengths: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            tokens: Input token sequences (batch, max_token_length)
            token_lengths: Length of each token sequence
            mel_specgram: Target mel spectrograms (for training)
            mel_specgram_lengths: Length of each mel spectrogram

        Returns:
            Tuple of (mel_outputs, mel_outputs_postnet, gate_outputs)
        """

class WaveRNN(torch.nn.Module):
    """WaveRNN vocoder for high-quality audio generation."""

    def __init__(self, upsample_scales: List[int], n_classes: int, hop_length: int,
                 n_res_block: int = 10, n_rnn: int = 512, n_fc: int = 512,
                 kernel_size: int = 5, n_freq: int = 128, padding: int = 2) -> None:
        """
        Args:
            upsample_scales: Upsampling scales for each layer
            n_classes: Number of output classes (for mu-law quantization)
            hop_length: Hop length for upsampling
            n_res_block: Number of residual blocks
            n_rnn: RNN hidden dimension
            n_fc: Fully connected layer dimension
            kernel_size: Convolution kernel size
            n_freq: Number of frequency bins
            padding: Convolution padding
        """

    def forward(self, x: torch.Tensor, mels: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input audio sequence (batch, time)
            mels: Mel spectrogram conditioning (batch, freq, time)

        Returns:
            Tensor: Output logits (batch, time, n_classes)
        """
```
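
One detail worth showing for WaveRNN: the mel conditioning is upsampled to the audio rate, so the product of `upsample_scales` must equal `hop_length`. The concrete values below are illustrative assumptions, not recommended settings.

```python
from torchaudio.models import WaveRNN

# 5 * 5 * 8 == 200, matching hop_length, so mel frames upsample exactly
# to the waveform rate; a mismatch here raises an error at construction.
vocoder = WaveRNN(upsample_scales=[5, 5, 8], n_classes=256, hop_length=200)
vocoder.eval()
```
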
### Source Separation Models

Neural networks for separating mixed audio into individual sources.

```python { .api }
class ConvTasNet(torch.nn.Module):
    """Convolutional Time-domain Audio Source Separation Network."""

    def __init__(self, num_sources: int = 2, enc_kernel_size: int = 16,
                 enc_num_feats: int = 512, msk_kernel_size: int = 3,
                 msk_num_feats: int = 128, msk_num_hidden_feats: int = 512,
                 msk_num_layers: int = 8, msk_num_stacks: int = 3,
                 msk_activate: str = "sigmoid") -> None:
        """
        Args:
            num_sources: Number of sources to separate
            enc_kernel_size: Encoder kernel size
            enc_num_feats: Number of encoder features
            msk_kernel_size: Mask generator kernel size
            msk_num_feats: Number of mask features
            msk_num_hidden_feats: Number of hidden features in mask generator
            msk_num_layers: Number of layers in each stack
            msk_num_stacks: Number of stacks
            msk_activate: Activation function for masks
        """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: Mixed single-channel waveform (batch, 1, time)

        Returns:
            Tensor: Separated sources (batch, num_sources, time)
        """

def conv_tasnet_base(num_sources: int) -> ConvTasNet:
    """Create base ConvTasNet model."""

class HDemucs(torch.nn.Module):
    """Hybrid Demucs model for music source separation."""

    def __init__(self, sources: List[str], audio_channels: int = 2, channels: int = 48,
                 growth: float = 2.0, nfft: int = 4096, wiener_iters: int = 0,
                 end_iters: int = 0, wiener_residual: bool = False, cac: bool = True,
                 depth: int = 6, rewrite: bool = True, hybrid: bool = True,
                 hybrid_old: bool = False, multi_freqs: Optional[List[int]] = None,
                 multi_freqs_depth: int = 2, freq_emb: Optional[int] = None,
                 emb_scale: int = 10, emb_smooth: bool = False,
                 kernel_size: int = 8, time_stride: int = 2, stride: int = 4,
                 context: int = 1, context_enc: int = 0, norm_starts: int = 4,
                 norm_groups: int = 4, dconv_mode: int = 1, dconv_depth: int = 2,
                 dconv_comp: int = 4, dconv_attn: int = 4, dconv_lstm: int = 4,
                 dconv_init: float = 1e-4, bottom_channels: int = 0,
                 clone_kw: Optional[Dict[str, Any]] = None, num_subbands: int = 1,
                 spec_complex: bool = True, segment_length: int = 4 * 10 * 44100) -> None:
        """
        Args:
            sources: List of source names to separate
            audio_channels: Number of audio channels
            channels: Base number of channels
            growth: Channel growth factor per layer
            nfft: FFT size for the spectral branch
            wiener_iters: Number of Wiener filtering iterations
            (additional parameters configure the encoder/decoder stacks)
        """

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """
        Args:
            wav: Input audio (batch, channels, time)

        Returns:
            Tensor: Separated sources (batch, sources, channels, time)
        """

def hdemucs_low(sources: List[str]) -> HDemucs:
    """Create low-complexity HDemucs model."""

def hdemucs_medium(sources: List[str]) -> HDemucs:
    """Create medium HDemucs model."""

def hdemucs_high(sources: List[str]) -> HDemucs:
    """Create high-quality HDemucs model."""
```
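
The sketch below separates a dummy two-speaker mixture with an untrained ConvTasNet; the batch shape and assumed 8 kHz duration are for illustration only, and real use would load pre-trained weights.

```python
import torch
from torchaudio.models import conv_tasnet_base

# Untrained two-source ConvTasNet (calling convention only).
model = conv_tasnet_base(num_sources=2)
model.eval()

# Dummy single-channel mixture: (batch, 1, time), one second at an assumed 8 kHz.
mixture = torch.randn(1, 1, 8000)

with torch.no_grad():
    sources = model(mixture)  # (batch, num_sources, time)
```
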
### Speech Quality Assessment Models

Models for objective and subjective speech quality assessment.

```python { .api }
class SquimObjective(torch.nn.Module):
    """SQUIM model for objective speech quality assessment."""

    def __init__(self, encoder: torch.nn.Module, classifier: torch.nn.Module) -> None:
        """
        Args:
            encoder: Feature encoder network
            classifier: Quality prediction classifier
        """

    def forward(self, waveforms: torch.Tensor) -> List[torch.Tensor]:
        """
        Args:
            waveforms: Input audio (batch, time)

        Returns:
            List of score tensors (STOI, PESQ, SI-SDR), each of shape (batch,)
        """

class SquimSubjective(torch.nn.Module):
    """SQUIM model for subjective speech quality assessment."""

    def __init__(self, encoder: torch.nn.Module, classifier: torch.nn.Module) -> None:
        """
        Args:
            encoder: Feature encoder network
            classifier: Quality prediction classifier
        """

    def forward(self, waveforms: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        """
        Args:
            waveforms: Input audio (batch, time)
            reference: Non-matching reference audio (batch, time)

        Returns:
            Tensor: Subjective quality scores (MOS), shape (batch,)
        """

def squim_objective_base() -> SquimObjective:
    """Create SQUIM objective model with the default configuration."""

def squim_objective_model() -> SquimObjective:
    """Create SQUIM objective model."""

def squim_subjective_base() -> SquimSubjective:
    """Create SQUIM subjective model with the default configuration."""

def squim_subjective_model() -> SquimSubjective:
    """Create SQUIM subjective model."""
```
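
A minimal sketch of objective quality scoring follows: the forward pass returns one score tensor per metric, and the dummy waveform with its assumed 16 kHz rate is for illustration only (meaningful scores require pre-trained weights from the SQUIM pipeline bundles).

```python
import torch
from torchaudio.models import squim_objective_base

# Untrained SQUIM objective model (calling convention only).
model = squim_objective_base()
model.eval()

# Dummy batch: one 1-second clip at an assumed 16 kHz.
waveform = torch.randn(1, 16000)

with torch.no_grad():
    stoi, pesq, si_sdr = model(waveform)  # three (batch,) score tensors
```
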
### Decoder Utilities

Utilities for decoding model outputs, particularly for sequence-to-sequence models.

```python { .api }
class RNNTBeamSearch(torch.nn.Module):
    """Beam search decoder for RNN-Transducer models."""

    def __init__(self, model: RNNT, blank: int, temperature: float = 1.0,
                 hyp_sort_score: Optional[Callable] = None,
                 token_sort_score: Optional[Callable] = None) -> None:
        """
        Args:
            model: RNN-T model to decode
            blank: Blank token index
            temperature: Temperature for softmax
            hyp_sort_score: Function to score hypotheses
            token_sort_score: Function to score tokens
        """

    def forward(self, input: torch.Tensor, length: torch.Tensor, beam_width: int,
                max_symbol_per_frame: Optional[int] = None) -> List[List[Hypothesis]]:
        """
        Args:
            input: Input features (batch, time, feature_dim)
            length: Length of each sequence
            beam_width: Beam search width
            max_symbol_per_frame: Maximum symbols per frame

        Returns:
            List of hypotheses for each batch item
        """

class Hypothesis:
    """Hypothesis object for beam search."""

    score: float
    y_sequence: List[int]
    dec_state: List[List[torch.Tensor]]
    lm_state: Optional[Any]
    lm_score: Optional[torch.Tensor]
    tokens: Optional[torch.Tensor]
    timestep: Optional[torch.Tensor]
    last_token: Optional[int]

    def __init__(self, score: float, y_sequence: List[int], dec_state: List[List[torch.Tensor]],
                 lm_state: Optional[Any] = None, lm_score: Optional[torch.Tensor] = None,
                 tokens: Optional[torch.Tensor] = None, timestep: Optional[torch.Tensor] = None,
                 last_token: Optional[int] = None) -> None:
        """
        Args:
            score: Hypothesis score
            y_sequence: Sequence of predicted tokens
            dec_state: Decoder state
            lm_state: Language model state
            lm_score: Language model score
            tokens: Token probabilities
            timestep: Current timestep
            last_token: Last predicted token
        """
```
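
To show how the decoder ties into the RNN-T factory functions, here is a hedged sketch of beam search over dummy features; the vocabulary size, blank index, feature dimension, and sequence length are all illustrative assumptions rather than fixed API values.

```python
import torch
from torchaudio.models import RNNTBeamSearch, emformer_rnnt_base

# Untrained Emformer RNN-T with an assumed 1024-token vocabulary,
# using the last index as the blank token (a common convention).
model = emformer_rnnt_base(num_symbols=1024)
model.eval()

decoder = RNNTBeamSearch(model, blank=1023)

# Dummy utterance: (batch=1, time, feature_dim) with an assumed 80-dim frontend.
features = torch.randn(1, 100, 80)
length = torch.tensor([100])

with torch.no_grad():
    hypotheses = decoder(features, length, beam_width=4)

best = hypotheses[0][0]       # top hypothesis for the only batch item
token_ids = best.y_sequence   # predicted token IDs (see Hypothesis above)
```
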
Usage example:

```python
import torch
import torchaudio
from torchaudio.models import wav2vec2_base, Tacotron2

# Build a Wav2Vec2 model with a 32-class auxiliary head for character recognition
model = wav2vec2_base(aux_num_out=32)
model.eval()

# Process audio with Wav2Vec2
waveform, sample_rate = torchaudio.load("speech.wav")
with torch.no_grad():
    logits, lengths = model(waveform)  # the auxiliary head is applied inside forward

# Create Tacotron2 for TTS
tts_model = Tacotron2()
tts_model.eval()

# Synthesize speech (the tokens would normally come from text processing)
tokens = torch.randint(0, 188, (1, 50))  # random token IDs for illustration
token_lengths = torch.tensor([50])

with torch.no_grad():
    mel_outputs, mel_outputs_postnet, gate_outputs = tts_model(tokens, token_lengths)
```

These models provide state-of-the-art capabilities for a wide range of audio processing tasks and can serve as building blocks for more complex applications.