or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

audio-io.md · datasets.md · effects.md · functional.md · index.md · models.md · pipelines.md · streaming.md · transforms.md · utils.md

docs/models.md

0

# Pre-trained Models

1

2

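Neural network architectures for speech recognition, speech synthesis, source separation, and speech quality assessment. TorchAudio provides implementations of state-of-the-art models, along with factory functions that build them in standard configurations. The factory functions return models with randomly initialized weights. To obtain pre-trained weights, use the corresponding bundles in `torchaudio.pipelines` (for example, `torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H.get_model()`).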

3

4

## Capabilities

5

6

### Speech Recognition Models

7

8

Neural networks for automatic speech recognition and speech representation learning.

9

10

```python { .api }

11

class Wav2Vec2Model(torch.nn.Module):

12

"""Wav2Vec2 model for speech representation learning."""

13

14

def __init__(self, feature_extractor: torch.nn.Module, encoder: torch.nn.Module,

15

aux: Optional[torch.nn.Module] = None) -> None:

16

"""

17

Args:

18

feature_extractor: CNN feature extractor

19

encoder: Transformer encoder

20

aux: Auxiliary output layer (for fine-tuned models)

21

"""

22

23

def forward(self, waveforms: torch.Tensor, lengths: Optional[torch.Tensor] = None) -> Wav2Vec2ModelOutput:

24

"""

25

Args:

26

waveforms: Input audio (..., time)

27

lengths: Length of each sequence in batch

28

29

Returns:

30

Wav2Vec2ModelOutput with last_hidden_state, extract_features, etc.

31

"""

32

33

def wav2vec2_model(arch: str, num_out: Optional[int] = None) -> Wav2Vec2Model:

34

"""Create Wav2Vec2 model with specified architecture."""

35

36

def wav2vec2_base(num_out: Optional[int] = None) -> Wav2Vec2Model:

37

"""Create base Wav2Vec2 model (12 layers, 768 dim)."""

38

39

def wav2vec2_large(num_out: Optional[int] = None) -> Wav2Vec2Model:

40

"""Create large Wav2Vec2 model (24 layers, 1024 dim)."""

41

42

def wav2vec2_large_lv60k(num_out: Optional[int] = None) -> Wav2Vec2Model:

43

"""Create large Wav2Vec2 model pre-trained on Libri-Light."""

44

45

def wav2vec2_xlsr_300m(num_out: Optional[int] = None) -> Wav2Vec2Model:

46

"""Create XLSR-53 300M parameter multilingual model."""

47

48

def wav2vec2_xlsr_1b(num_out: Optional[int] = None) -> Wav2Vec2Model:

49

"""Create XLSR-53 1B parameter multilingual model."""

50

51

def wav2vec2_xlsr_2b(num_out: Optional[int] = None) -> Wav2Vec2Model:

52

"""Create XLSR-53 2B parameter multilingual model."""

53

54

class HuBERTPretrainModel(torch.nn.Module):

55

"""HuBERT model for self-supervised speech representation learning."""

56

57

def __init__(self, feature_extractor: torch.nn.Module, encoder: torch.nn.Module,

58

final_proj: torch.nn.Module, label_embs_concat: torch.nn.Module,

59

mask_generator: torch.nn.Module, logit_temp: float) -> None:

60

"""

61

Args:

62

feature_extractor: CNN feature extractor

63

encoder: Transformer encoder

64

final_proj: Final projection layer

65

label_embs_concat: Label embedding concatenation

66

mask_generator: Mask generator for pre-training

67

logit_temp: Temperature for logits

68

"""

69

70

def forward(self, waveforms: torch.Tensor, labels: Optional[torch.Tensor] = None,

71

audio_lengths: Optional[torch.Tensor] = None) -> HuBERTPretrainModelOutput:

72

"""

73

Args:

74

waveforms: Input audio (..., time)

75

labels: Target labels for pre-training

76

audio_lengths: Length of each sequence

77

78

Returns:

79

HuBERTPretrainModelOutput with logits, features, etc.

80

"""

81

82

def hubert_base(aux_num_out: Optional[int] = None) -> Wav2Vec2Model:

83

"""Create base HuBERT model."""

84

85

def hubert_large(aux_num_out: Optional[int] = None) -> Wav2Vec2Model:

86

"""Create large HuBERT model."""

87

88

def hubert_xlarge(aux_num_out: Optional[int] = None) -> Wav2Vec2Model:

89

"""Create extra-large HuBERT model."""

90

91

def hubert_pretrain_model(arch: str, aux_num_out: Optional[int] = None) -> HuBERTPretrainModel:

92

"""Create HuBERT pre-training model."""

93

94

def wavlm_model(arch: str, aux_num_out: Optional[int] = None) -> Wav2Vec2Model:

95

"""Create WavLM model with specified architecture."""

96

97

def wavlm_base(aux_num_out: Optional[int] = None) -> Wav2Vec2Model:

98

"""Create base WavLM model."""

99

100

def wavlm_large(aux_num_out: Optional[int] = None) -> Wav2Vec2Model:

101

"""Create large WavLM model."""

102

```

103

104

### Legacy Speech Recognition Models

105

106

Earlier recurrent (DeepSpeech) and convolutional (Wav2Letter) architectures for end-to-end speech recognition.

107

108

```python { .api }

109

class DeepSpeech(torch.nn.Module):

110

"""DeepSpeech model for end-to-end speech recognition."""

111

112

def __init__(self, n_hidden: int, n_class: int) -> None:

113

"""

114

Args:

115

n_hidden: Number of hidden units in RNN layers

116

n_class: Number of output classes (characters/phonemes)

117

"""

118

119

def forward(self, x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:

120

"""

121

Args:

122

x: Input features (..., freq, time)

123

lengths: Length of each sequence

124

125

Returns:

126

Tensor: Logits over character classes (..., time, n_class)

127

"""

128

129

class Wav2Letter(torch.nn.Module):

130

"""Wav2Letter model for speech recognition."""

131

132

def __init__(self, num_classes: int, input_type: str = "waveform",

133

num_features: Optional[int] = None, num_hidden: int = 1000) -> None:

134

"""

135

Args:

136

num_classes: Number of output classes

137

input_type: Type of input ("waveform" or "features")

138

num_features: Number of input features (required if input_type="features")

139

num_hidden: Number of hidden units

140

"""

141

142

def forward(self, x: torch.Tensor) -> torch.Tensor:

143

"""

144

Args:

145

x: Input tensor (waveform or features)

146

147

Returns:

148

Tensor: Class probabilities

149

"""

150

```

151

152

### RNN-Transducer Models

153

154

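Recurrent neural network transducer (RNN-T) models, plus the Conformer and Emformer encoder architectures, for speech recognition. Emformer is designed for streaming (low-latency) recognition.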

155

156

```python { .api }

157

class RNNT(torch.nn.Module):

158

"""RNN-Transducer model for streaming speech recognition."""

159

160

def __init__(self, transcriber: torch.nn.Module, predictor: torch.nn.Module,

161

joiner: torch.nn.Module) -> None:

162

"""

163

Args:

164

transcriber: Encoder network (processes audio features)

165

predictor: Decoder network (processes previous predictions)

166

joiner: Joint network (combines encoder and decoder outputs)

167

"""

168

169

def forward(self, sources: torch.Tensor, source_lengths: torch.Tensor,

170

targets: torch.Tensor, target_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

171

"""

172

Args:

173

sources: Input audio features (batch, time, feature_dim)

174

source_lengths: Length of each audio sequence

175

targets: Target token sequences (batch, target_time)

176

target_lengths: Length of each target sequence

177

178

Returns:

179

Tuple of (transcriber_out, predictor_out, joiner_out)

180

"""

181

182

class Conformer(torch.nn.Module):

183

"""Conformer model combining CNN and self-attention."""

184

185

def __init__(self, input_dim: int, num_heads: int, ffn_dim: int, num_layers: int,

186

depthwise_conv_kernel_size: int = 31, dropout: float = 0.1,

187

use_group_norm: bool = False, convolution_first: bool = False) -> None:

188

"""

189

Args:

190

input_dim: Input feature dimension

191

num_heads: Number of attention heads

192

ffn_dim: Feed-forward network dimension

193

num_layers: Number of conformer layers

194

depthwise_conv_kernel_size: Kernel size for depthwise convolution

195

dropout: Dropout probability

196

use_group_norm: Whether to use group normalization

197

convolution_first: Whether to apply convolution before self-attention

198

"""

199

200

def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

201

"""

202

Args:

203

input: Input features (batch, time, feature_dim)

204

lengths: Length of each sequence

205

206

Returns:

207

Tuple of (output, output_lengths)

208

"""

209

210

class Emformer(torch.nn.Module):

211

"""Emformer model for streaming applications."""

212

213

def __init__(self, input_dim: int, num_heads: int, ffn_dim: int, num_layers: int,

214

segment_length: int, left_context_length: int = 0,

215

right_context_length: int = 0, max_memory_size: int = 0,

216

weight_init_scale_strategy: str = "depthwise", tanh_on_mem: bool = False,

217

negative_inf: float = -1e8) -> None:

218

"""

219

Args:

220

input_dim: Input feature dimension

221

num_heads: Number of attention heads

222

ffn_dim: Feed-forward dimension

223

num_layers: Number of layers

224

segment_length: Length of each segment

225

left_context_length: Left context length

226

right_context_length: Right context length

227

max_memory_size: Maximum memory size

228

weight_init_scale_strategy: Weight initialization strategy

229

tanh_on_mem: Whether to apply tanh on memory

230

negative_inf: Negative infinity value for masking

231

"""

232

233

def forward(self, input: torch.Tensor, lengths: torch.Tensor,

234

mems: Optional[List[List[torch.Tensor]]] = None) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:

235

"""

236

Args:

237

input: Input features (batch, time, feature_dim)

238

lengths: Length of each sequence

239

mems: Previous memory states

240

241

Returns:

242

Tuple of (output, output_lengths, new_mems)

243

"""

244

245

def emformer_rnnt_base(num_symbols: int) -> RNNT:

246

"""Create base Emformer RNN-T model."""

247

248

def emformer_rnnt_model(arch: str, num_symbols: int) -> RNNT:

249

"""Create Emformer RNN-T model with specified architecture."""

250

```

251

252

### Speech Synthesis Models

253

254

Neural networks for text-to-speech synthesis and vocoding.

255

256

```python { .api }

257

class Tacotron2(torch.nn.Module):

258

"""Tacotron2 model for text-to-speech synthesis."""

259

260

def __init__(self, mask_padding: bool = False, n_mels: int = 80,

261

n_frames_per_step: int = 1, n_characters: int = 188,

262

n_hidden: int = 1024, p_attention_dropout: float = 0.1,

263

p_decoder_dropout: float = 0.1, prenet_dim: int = 256,

264

postnet_embedding_dim: int = 512, postnet_kernel_size: int = 5,

265

postnet_n_convolutions: int = 5, postnet_dropout: float = 0.5,

266

attention_rnn_dim: int = 1024, attention_dim: int = 128,

267

attention_location_n_filters: int = 32, attention_location_kernel_size: int = 31,

268

encoder_embedding_dim: int = 512, encoder_n_convolutions: int = 3,

269

encoder_kernel_size: int = 5, encoder_dropout: float = 0.5,

270

decoder_rnn_dim: int = 1024, decoder_max_step: int = 2000,

271

gate_threshold: float = 0.5, p_teacher_forcing: float = 1.0,

272

decoder_dropout: float = 0.1, memory_dropout: float = 0.1) -> None:

273

"""

274

Args:

275

mask_padding: Whether to mask padding in loss computation

276

n_mels: Number of mel frequency bins

277

n_frames_per_step: Number of frames generated per step

278

(additional parameters for model architecture configuration)

279

"""

280

281

def forward(self, tokens: torch.Tensor, token_lengths: torch.Tensor,

282

mel_specgram: Optional[torch.Tensor] = None,

283

mel_specgram_lengths: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

284

"""

285

Args:

286

tokens: Input token sequences (batch, max_token_length)

287

token_lengths: Length of each token sequence

288

mel_specgram: Target mel spectrograms (for training)

289

mel_specgram_lengths: Length of each mel spectrogram

290

291

Returns:

292

Tuple of (mel_outputs, mel_outputs_postnet, gate_outputs)

293

"""

294

295

class WaveRNN(torch.nn.Module):

296

"""WaveRNN vocoder for high-quality audio generation."""

297

298

def __init__(self, upsample_scales: List[int], n_classes: int, hop_length: int,

299

n_res_block: int = 10, n_rnn: int = 512, n_fc: int = 512,

300

kernel_size: int = 5, n_freq: int = 128, padding: int = 2) -> None:

301

"""

302

Args:

303

upsample_scales: Upsampling scales for each layer

304

n_classes: Number of output classes (for mu-law quantization)

305

hop_length: Hop length for upsampling

306

n_res_block: Number of residual blocks

307

n_rnn: RNN hidden dimension

308

n_fc: Fully connected layer dimension

309

kernel_size: Convolution kernel size

310

n_freq: Number of frequency bins

311

padding: Convolution padding

312

"""

313

314

def forward(self, x: torch.Tensor, mels: torch.Tensor) -> torch.Tensor:

315

"""

316

Args:

317

x: Input audio sequence (batch, time)

318

mels: Mel spectrogram conditioning (batch, freq, time)

319

320

Returns:

321

Tensor: Output logits (batch, time, n_classes)

322

"""

323

```

324

325

### Source Separation Models

326

327

Neural networks for separating mixed audio into individual sources.

328

329

```python { .api }

330

class ConvTasNet(torch.nn.Module):

331

"""Convolutional Time-domain Audio Source Separation Network."""

332

333

def __init__(self, num_sources: int = 2, enc_kernel_size: int = 16,

334

enc_num_feats: int = 512, msk_kernel_size: int = 3,

335

msk_num_feats: int = 128, msk_num_hidden_feats: int = 512,

336

msk_num_layers: int = 8, msk_num_stacks: int = 3,

337

msk_activate: str = "sigmoid") -> None:

338

"""

339

Args:

340

num_sources: Number of sources to separate

341

enc_kernel_size: Encoder kernel size

342

enc_num_feats: Number of encoder features

343

msk_kernel_size: Mask generator kernel size

344

msk_num_feats: Number of mask features

345

msk_num_hidden_feats: Number of hidden features in mask generator

346

msk_num_layers: Number of layers in each stack

347

msk_num_stacks: Number of stacks

348

msk_activate: Activation function for masks

349

"""

350

351

def forward(self, input: torch.Tensor) -> torch.Tensor:

352

"""

353

Args:

354

input: Mixed audio waveform (batch, time)

355

356

Returns:

357

Tensor: Separated sources (batch, num_sources, time)

358

"""

359

360

def conv_tasnet_base(num_sources: int) -> ConvTasNet:

361

"""Create base ConvTasNet model."""

362

363

class HDemucs(torch.nn.Module):

364

"""Hybrid Demucs model for music source separation."""

365

366

def __init__(self, sources: List[str], audio_channels: int = 2, channels: int = 48,

367

growth: float = 2.0, nfft: int = 4096, wiener_iters: int = 0,

368

end_iters: int = 0, wiener_residual: bool = False, cac: bool = True,

369

depth: int = 6, rewrite: bool = True, hybrid: bool = True,

370

hybrid_old: bool = False, multi_freqs: List[int] = None,

371

multi_freqs_depth: int = 2, freq_emb: Optional[int] = None,

372

emb_scale: int = 10, emb_smooth: bool = False,

373

kernel_size: int = 8, time_stride: int = 2, stride: int = 4,

374

context: int = 1, context_enc: int = 0, norm_starts: int = 4,

375

norm_groups: int = 4, dconv_mode: int = 1, dconv_depth: int = 2,

376

dconv_comp: int = 4, dconv_attn: int = 4, dconv_lstm: int = 4,

377

dconv_init: float = 1e-4, bottom_channels: int = 0,

378

clone_kw: Dict[str, Any] = None, num_subbands: int = 1,

379

spec_complex: bool = True, segment_length: int = 4 * 10 * 44100) -> None:

380

"""

381

Args:

382

sources: List of source names to separate

383

audio_channels: Number of audio channels

384

channels: Base number of channels

385

growth: Channel growth factor per layer

386

nfft: FFT size for spectral branch

387

wiener_iters: Number of Wiener filtering iterations

388

(additional parameters for model configuration)

389

"""

390

391

def forward(self, wav: torch.Tensor) -> torch.Tensor:

392

"""

393

Args:

394

wav: Input audio (batch, channels, time)

395

396

Returns:

397

Tensor: Separated sources (batch, sources, channels, time)

398

"""

399

400

def hdemucs_low() -> HDemucs:

401

"""Create low-complexity HDemucs model."""

402

403

def hdemucs_medium() -> HDemucs:

404

"""Create medium HDemucs model."""

405

406

def hdemucs_high() -> HDemucs:

407

"""Create high-quality HDemucs model."""

408

```

409

410

### Speech Quality Assessment Models

411

412

Models that estimate speech quality directly from audio: objective metrics (STOI, PESQ, SI-SDR) and subjective mean opinion scores (MOS).

413

414

```python { .api }

415

class SquimObjective(torch.nn.Module):

416

"""SQUIM model for objective speech quality assessment."""

417

418

def __init__(self, encoder: torch.nn.Module, classifier: torch.nn.Module) -> None:

419

"""

420

Args:

421

encoder: Feature encoder network

422

classifier: Quality prediction classifier

423

"""

424

425

def forward(self, waveforms: torch.Tensor) -> torch.Tensor:

426

"""

427

Args:

428

waveforms: Input audio (batch, time)

429

430

Returns:

431

Tensor: Quality scores (STOI, PESQ, SI-SDR)

432

"""

433

434

class SquimSubjective(torch.nn.Module):

435

"""SQUIM model for subjective speech quality assessment."""

436

437

def __init__(self, encoder: torch.nn.Module, classifier: torch.nn.Module) -> None:

438

"""

439

Args:

440

encoder: Feature encoder network

441

classifier: Quality prediction classifier

442

"""

443

444

def forward(self, waveforms: torch.Tensor) -> torch.Tensor:

445

"""

446

Args:

447

waveforms: Input audio (batch, time)

448

449

Returns:

450

Tensor: Subjective quality scores (MOS)

451

"""

452

453

def squim_objective_base() -> SquimObjective:

454

"""Create base SQUIM objective model."""

455

456

def squim_objective_model() -> SquimObjective:

457

"""Create SQUIM objective model."""

458

459

def squim_subjective_base() -> SquimSubjective:

460

"""Create base SQUIM subjective model."""

461

462

def squim_subjective_model() -> SquimSubjective:

463

"""Create SQUIM subjective model."""

464

```

465

466

### Decoder Utilities

467

468

Utilities for decoding model outputs, in particular beam search decoding for RNN-Transducer (RNN-T) models.

469

470

```python { .api }

471

class RNNTBeamSearch(torch.nn.Module):

472

"""Beam search decoder for RNN-Transducer models."""

473

474

def __init__(self, model: RNNT, blank: int, temperature: float = 1.0,

475

hyp_sort_score: Optional[Callable] = None,

476

token_sort_score: Optional[Callable] = None) -> None:

477

"""

478

Args:

479

model: RNN-T model to decode

480

blank: Blank token index

481

temperature: Temperature for softmax

482

hyp_sort_score: Function to score hypotheses

483

token_sort_score: Function to score tokens

484

"""

485

486

def forward(self, input: torch.Tensor, length: torch.Tensor, beam_width: int,

487

max_symbol_per_frame: Optional[int] = None) -> List[List[Hypothesis]]:

488

"""

489

Args:

490

input: Input features (batch, time, feature_dim)

491

length: Length of each sequence

492

beam_width: Beam search width

493

max_symbol_per_frame: Maximum symbols per frame

494

495

Returns:

496

List of hypotheses for each batch item

497

"""

498

499

class Hypothesis:

500

"""Hypothesis object for beam search."""

501

502

def __init__(self, score: float, y_sequence: List[int], dec_state: List[List[torch.Tensor]],

503

lm_state: Optional[Any] = None, lm_score: Optional[torch.Tensor] = None,

504

tokens: Optional[torch.Tensor] = None, timestep: Optional[torch.Tensor] = None,

505

last_token: Optional[int] = None) -> None:

506

"""

507

Args:

508

score: Hypothesis score

509

y_sequence: Sequence of predicted tokens

510

dec_state: Decoder state

511

lm_state: Language model state

512

lm_score: Language model score

513

tokens: Token probabilities

514

timestep: Current timestep

515

last_token: Last predicted token

516

"""

517

518

score: float

519

y_sequence: List[int]

520

dec_state: List[List[torch.Tensor]]

521

lm_state: Optional[Any]

522

lm_score: Optional[torch.Tensor]

523

tokens: Optional[torch.Tensor]

524

timestep: Optional[torch.Tensor]

525

last_token: Optional[int]

526

```

527

528

Usage example:

529

530

```python

531

import torch

532

import torchaudio

533

from torchaudio.models import wav2vec2_base, Tacotron2

534

535

# Load pre-trained Wav2Vec2 model

536

model = wav2vec2_base(num_out=32) # 32 output classes for character recognition

537

model.eval()

538

539

# Process audio with Wav2Vec2

540

waveform, sample_rate = torchaudio.load("speech.wav")

541

with torch.no_grad():

542

features, lengths = model(waveform) # Extract features

543

logits = model.aux(features) # Get classification logits

544

545

# Create Tacotron2 for TTS

546

tts_model = Tacotron2()

547

tts_model.eval()

548

549

# Synthesize speech (tokens would come from text processing)

550

tokens = torch.randint(0, 188, (1, 50)) # Random tokens for example

551

token_lengths = torch.tensor([50])

552

553

with torch.no_grad():

554

mel_outputs, mel_outputs_postnet, gate_outputs = tts_model(tokens, token_lengths)

555

```

556

557

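These models provide state-of-the-art capabilities for a range of audio processing tasks and can serve as building blocks for more complex applications. Note that the factory functions shown above create untrained models. Before using a model for inference, load pre-trained weights through the bundles in `torchaudio.pipelines`.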