or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

audio-io.md, datasets.md, effects.md, functional.md, index.md, models.md, pipelines.md, streaming.md, transforms.md, utils.md

transforms.mddocs/

0

# Audio Transforms

1

2

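PyTorch-compatible transform classes for building differentiable audio processing pipelines. These transforms are `torch.nn.Module` subclasses, so they can be composed with neural networks (e.g. inside `torch.nn.Sequential`); where the underlying operation is differentiable, gradients flow through them, enabling end-to-end training of the surrounding model.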

3

4

## Capabilities

5

6

### Spectral Transforms

7

8

Core spectral analysis transforms for converting between time and frequency domains.

9

10

```python { .api }

11

class Spectrogram(torch.nn.Module):

12

"""Compute spectrogram of audio signal."""

13

14

def __init__(self, n_fft: int = 400, win_length: Optional[int] = None,

15

hop_length: Optional[int] = None, pad: int = 0,

16

window_fn: Callable[..., torch.Tensor] = torch.hann_window,

17

power: Optional[float] = 2.0, normalized: bool = False,

18

wkwargs: Optional[Dict[str, Any]] = None, center: bool = True,

19

pad_mode: str = "reflect", onesided: bool = True) -> None:

20

"""

21

Args:

22

n_fft: Size of FFT

23

win_length: Window size (defaults to n_fft)

24

hop_length: Length of hop between STFT windows (defaults to win_length // 2)

25

pad: Two-sided padding of signal

26

window_fn: Window function (e.g., torch.hann_window, torch.hamming_window)

27

power: Exponent for the magnitude spectrogram (1.0 for magnitude, 2.0 for power, None to return the complex spectrum)

28

normalized: Whether to normalize by window and n_fft

29

wkwargs: Additional arguments for window function

30

center: Whether to pad waveform on both sides

31

pad_mode: Padding mode for centering

32

onesided: Controls whether to return half of results

33

"""

34

35

def forward(self, waveform: torch.Tensor) -> torch.Tensor:

36

"""

37

Args:

38

waveform: Input tensor (..., time)

39

40

Returns:

41

Tensor: Spectrogram (..., freq, time)

42

"""

43

44

class InverseSpectrogram(torch.nn.Module):

45

"""Reconstruct waveform from spectrogram using inverse STFT."""

46

47

def __init__(self, n_fft: int = 400, win_length: Optional[int] = None,

48

hop_length: Optional[int] = None, pad: int = 0,

49

window_fn: Callable[..., torch.Tensor] = torch.hann_window,

50

normalized: bool = False, wkwargs: Optional[Dict[str, Any]] = None,

51

center: bool = True, pad_mode: str = "reflect",

52

onesided: bool = True, length: Optional[int] = None) -> None:

53

"""

54

Args:

55

length: Expected length of reconstructed signal

56

(other parameters same as Spectrogram)

57

"""

58

59

def forward(self, spectrogram: torch.Tensor) -> torch.Tensor:

60

"""

61

Args:

62

spectrogram: Input spectrogram (..., freq, time)

63

64

Returns:

65

Tensor: Reconstructed waveform (..., time)

66

"""

67

68

class GriffinLim(torch.nn.Module):

69

"""Reconstruct waveform from magnitude spectrogram using Griffin-Lim algorithm."""

70

71

def __init__(self, n_fft: int = 400, n_iter: int = 32, win_length: Optional[int] = None,

72

hop_length: Optional[int] = None, window_fn: Callable[..., torch.Tensor] = torch.hann_window,

73

power: float = 2.0, wkwargs: Optional[Dict[str, Any]] = None,

74

momentum: float = 0.99, length: Optional[int] = None,

75

rand_init: bool = True) -> None:

76

"""

77

Args:

78

n_iter: Number of Griffin-Lim iterations

79

power: Exponent applied to spectrogram

80

momentum: Momentum parameter for fast Griffin-Lim (0 recovers the original Griffin-Lim; values above 1 may not converge)

81

rand_init: Whether to initialize with random phase

82

(other parameters same as Spectrogram)

83

"""

84

85

def forward(self, specgram: torch.Tensor) -> torch.Tensor:

86

"""

87

Args:

88

specgram: Magnitude spectrogram (..., freq, time)

89

90

Returns:

91

Tensor: Reconstructed waveform (..., time)

92

"""

93

```

94

95

### Mel-Scale Transforms

96

97

Transforms for mel-scale processing commonly used in speech and music analysis.

98

99

```python { .api }

100

class MelSpectrogram(torch.nn.Module):

101

"""Compute mel-scale spectrogram."""

102

103

def __init__(self, sample_rate: int = 16000, n_fft: int = 400,

104

win_length: Optional[int] = None, hop_length: Optional[int] = None,

105

f_min: float = 0.0, f_max: Optional[float] = None, n_mels: int = 128,

106

window_fn: Callable[..., torch.Tensor] = torch.hann_window,

107

power: float = 2.0, normalized: bool = False,

108

wkwargs: Optional[Dict[str, Any]] = None, center: bool = True,

109

pad_mode: str = "reflect", onesided: bool = True,

110

norm: Optional[str] = None, mel_scale: str = "htk") -> None:

111

"""

112

Args:

113

sample_rate: Sample rate of audio

114

f_min: Minimum frequency

115

f_max: Maximum frequency (defaults to sample_rate // 2)

116

n_mels: Number of mel filter banks

117

norm: Normalization method ("slaney" or None)

118

mel_scale: Scale to use ("htk" or "slaney")

119

(other parameters same as Spectrogram)

120

"""

121

122

def forward(self, waveform: torch.Tensor) -> torch.Tensor:

123

"""

124

Args:

125

waveform: Input tensor (..., time)

126

127

Returns:

128

Tensor: Mel spectrogram (..., n_mels, time)

129

"""

130

131

class MelScale(torch.nn.Module):

132

"""Convert normal spectrogram to mel-scale spectrogram."""

133

134

def __init__(self, n_mels: int = 128, sample_rate: int = 16000, f_min: float = 0.0,

135

f_max: Optional[float] = None, n_stft: Optional[int] = None,

136

norm: Optional[str] = None, mel_scale: str = "htk") -> None:

137

"""

138

Args:

139

n_mels: Number of mel filter banks

140

sample_rate: Sample rate of audio

141

f_min: Minimum frequency

142

f_max: Maximum frequency

143

n_stft: Number of STFT frequency bins (typically n_fft // 2 + 1)

144

norm: Normalization method

145

mel_scale: Scale to use

146

"""

147

148

def forward(self, specgram: torch.Tensor) -> torch.Tensor:

149

"""

150

Args:

151

specgram: Input spectrogram (..., freq, time)

152

153

Returns:

154

Tensor: Mel-scale spectrogram (..., n_mels, time)

155

"""

156

157

class InverseMelScale(torch.nn.Module):

158

"""Solve for normal spectrogram from mel-scale spectrogram using iterative method."""

159

160

def __init__(self, n_stft: int, n_mels: int = 128, sample_rate: int = 16000,

161

f_min: float = 0.0, f_max: Optional[float] = None,

162

max_iter: int = 100000, tolerance_loss: float = 1e-5,

163

tolerance_change: float = 1e-8, sgdargs: Optional[Dict[str, Any]] = None,

164

norm: Optional[str] = None, mel_scale: str = "htk") -> None:

165

"""

166

Args:

167

n_stft: Number of STFT frequency bins

168

max_iter: Maximum number of optimization iterations

169

tolerance_loss: Tolerance for loss convergence

170

tolerance_change: Tolerance for parameter change

171

sgdargs: Arguments for SGD optimizer

172

(other parameters same as MelScale)

173

"""

174

175

def forward(self, melspec: torch.Tensor) -> torch.Tensor:

176

"""

177

Args:

178

melspec: Mel-scale spectrogram (..., n_mels, time)

179

180

Returns:

181

Tensor: Linear spectrogram (..., n_stft, time)

182

"""

183

```

184

185

### Feature Extraction Transforms

186

187

Transforms for extracting common audio features.

188

189

```python { .api }

190

class MFCC(torch.nn.Module):

191

"""Compute Mel-frequency cepstral coefficients."""

192

193

def __init__(self, sample_rate: int = 16000, n_mfcc: int = 40,

194

dct_type: int = 2, norm: str = "ortho", log_mels: bool = False,

195

melkwargs: Optional[Dict[str, Any]] = None) -> None:

196

"""

197

Args:

198

sample_rate: Sample rate of audio

199

n_mfcc: Number of MFCC coefficients

200

dct_type: DCT type (only type 2 is supported)

201

norm: DCT normalization ("ortho" or None)

202

log_mels: Whether to use log-mel spectrograms (if False, the mel spectrogram is converted to decibels instead)

203

melkwargs: Additional arguments for MelSpectrogram

204

"""

205

206

def forward(self, waveform: torch.Tensor) -> torch.Tensor:

207

"""

208

Args:

209

waveform: Input tensor (..., time)

210

211

Returns:

212

Tensor: MFCC coefficients (..., n_mfcc, time)

213

"""

214

215

class LFCC(torch.nn.Module):

216

"""Compute Linear-frequency cepstral coefficients."""

217

218

def __init__(self, sample_rate: int = 16000, n_lfcc: int = 40,

219

speckwargs: Optional[Dict[str, Any]] = None, n_filter: int = 128,

220

f_min: float = 0.0, f_max: Optional[float] = None,

221

dct_type: int = 2, norm: str = "ortho", log_lf: bool = False) -> None:

222

"""

223

Args:

224

sample_rate: Sample rate of audio

225

n_lfcc: Number of LFCC coefficients

226

speckwargs: Additional arguments for Spectrogram

227

n_filter: Number of linear filter banks

228

f_min: Minimum frequency

229

f_max: Maximum frequency

230

dct_type: DCT type (only type 2 is supported)

231

norm: DCT normalization

232

log_lf: Whether to use log linear spectrograms

233

"""

234

235

def forward(self, waveform: torch.Tensor) -> torch.Tensor:

236

"""

237

Args:

238

waveform: Input tensor (..., time)

239

240

Returns:

241

Tensor: LFCC coefficients (..., n_lfcc, time)

242

"""

243

244

class ComputeDeltas(torch.nn.Module):

245

"""Compute delta features (first derivatives) of input features."""

246

247

def __init__(self, win_length: int = 5, mode: str = "replicate") -> None:

248

"""

249

Args:

250

win_length: Window length for delta computation

251

mode: Padding mode for computing deltas

252

"""

253

254

def forward(self, specgram: torch.Tensor) -> torch.Tensor:

255

"""

256

Args:

257

specgram: Input features (..., freq, time)

258

259

Returns:

260

Tensor: Delta features with same shape

261

"""

262

263

class SpectralCentroid(torch.nn.Module):

264

"""Compute spectral centroid."""

265

266

def __init__(self, sample_rate: int, n_fft: int = 400, win_length: Optional[int] = None,

267

hop_length: Optional[int] = None, pad: int = 0,

268

window_fn: Callable[..., torch.Tensor] = torch.hann_window,

269

wkwargs: Optional[Dict[str, Any]] = None) -> None:

270

"""

271

Args:

272

sample_rate: Sample rate of audio

273

(other parameters same as Spectrogram)

274

"""

275

276

def forward(self, waveform: torch.Tensor) -> torch.Tensor:

277

"""

278

Args:

279

waveform: Input tensor (..., time)

280

281

Returns:

282

Tensor: Spectral centroid (..., time)

283

"""

284

285

class Loudness(torch.nn.Module):

286

"""Compute loudness using ITU-R BS.1770-4 standard."""

287

288

def __init__(self, sample_rate: int) -> None:

289

"""

290

Args:

291

sample_rate: Sample rate of audio

292

"""

293

294

def forward(self, waveform: torch.Tensor) -> torch.Tensor:

295

"""

296

Args:

297

waveform: Input tensor (..., time)

298

299

Returns:

300

Tensor: Loudness in LUFS

301

"""

302

```

303

304

### Amplitude and Encoding Transforms

305

306

Transforms for amplitude scaling and audio encoding.

307

308

```python { .api }

309

class AmplitudeToDB(torch.nn.Module):

310

"""Convert amplitude spectrogram to decibel scale."""

311

312

def __init__(self, stype: str = "power", top_db: Optional[float] = None) -> None:

313

"""

314

Args:

315

stype: Spectrogram type ("power" or "magnitude")

316

top_db: Minimum negative cut-off in decibels

317

"""

318

319

def forward(self, x: torch.Tensor) -> torch.Tensor:

320

"""

321

Args:

322

x: Input spectrogram (..., freq, time)

323

324

Returns:

325

Tensor: Spectrogram in decibel scale

326

"""

327

328

class MuLawEncoding(torch.nn.Module):

329

"""Encode waveform using mu-law companding."""

330

331

def __init__(self, quantization_channels: int = 256) -> None:

332

"""

333

Args:

334

quantization_channels: Number of quantization levels

335

"""

336

337

def forward(self, x: torch.Tensor) -> torch.Tensor:

338

"""

339

Args:

340

x: Input waveform (..., time)

341

342

Returns:

343

Tensor: Mu-law encoded signal

344

"""

345

346

class MuLawDecoding(torch.nn.Module):

347

"""Decode mu-law encoded waveform."""

348

349

def __init__(self, quantization_channels: int = 256) -> None:

350

"""

351

Args:

352

quantization_channels: Number of quantization levels

353

"""

354

355

def forward(self, x_mu: torch.Tensor) -> torch.Tensor:

356

"""

357

Args:

358

x_mu: Mu-law encoded signal (..., time)

359

360

Returns:

361

Tensor: Decoded waveform

362

"""

363

```

364

365

### Resampling and Time Manipulation

366

367

Transforms for changing sample rates and temporal characteristics.

368

369

```python { .api }

370

class Resample(torch.nn.Module):

371

"""Resample waveform to different sample rate."""

372

373

def __init__(self, orig_freq: int = 16000, new_freq: int = 16000,

374

resampling_method: str = "sinc_interp_kaiser",

375

lowpass_filter_width: int = 6, rolloff: float = 0.99,

376

beta: Optional[float] = None, dtype: torch.dtype = torch.float32) -> None:

377

"""

378

Args:

379

orig_freq: Original sample rate

380

new_freq: Target sample rate

381

resampling_method: Resampling algorithm

382

lowpass_filter_width: Width of lowpass filter

383

rolloff: Roll-off frequency

384

beta: Shape parameter for Kaiser window

385

dtype: Precision in which the resampling kernel is precomputed and cached (not the output dtype)

386

"""

387

388

def forward(self, waveform: torch.Tensor) -> torch.Tensor:

389

"""

390

Args:

391

waveform: Input tensor (..., time)

392

393

Returns:

394

Tensor: Resampled waveform

395

"""

396

397

class Speed(torch.nn.Module):

398

"""Adjust playback speed by resampling."""

399

400

def __init__(self, orig_freq: int, factor: float) -> None:

401

"""

402

Args:

403

orig_freq: Original sample rate

404

factor: Speed factor (>1.0 = faster, <1.0 = slower)

405

"""

406

407

def forward(self, waveform: torch.Tensor, lengths: Optional[torch.Tensor] = None) -> torch.Tensor:

408

"""

409

Args:

410

waveform: Input tensor (..., time)

411

lengths: Length of each sequence in batch

412

413

Returns:

414

Tensor: Speed-adjusted waveform

415

"""

416

417

class TimeStretch(torch.nn.Module):

418

"""Stretch time axis of spectrogram without changing pitch."""

419

420

def __init__(self, hop_length: Optional[int] = None, n_freq: int = 201,

421

fixed_rate: Optional[float] = None) -> None:

422

"""

423

Args:

424

hop_length: Hop length for phase vocoder

425

n_freq: Number of frequency bins

426

fixed_rate: Fixed speed-up factor (if None, the rate must be supplied to forward)

427

"""

428

429

def forward(self, complex_specgrams: torch.Tensor, rate: float = 1.0) -> torch.Tensor:

430

"""

431

Args:

432

complex_specgrams: Complex spectrogram (..., freq, time)

433

rate: Speed-up factor (>1.0 = faster, shorter output; <1.0 = slower, longer output)

434

435

Returns:

436

Tensor: Time-stretched spectrogram

437

"""

438

439

class PitchShift(torch.nn.Module):

440

"""Shift pitch without changing duration."""

441

442

def __init__(self, sample_rate: int, n_steps: float, bins_per_octave: int = 12,

443

n_fft: int = 512, win_length: Optional[int] = None,

444

hop_length: Optional[int] = None,

445

window: Optional[torch.Tensor] = None) -> None:

446

"""

447

Args:

448

sample_rate: Sample rate

449

n_steps: Number of semitones to shift

450

bins_per_octave: Number of steps per octave

451

n_fft: FFT size

452

win_length: Window length

453

hop_length: Hop length

454

window: Window function

455

"""

456

457

def forward(self, waveform: torch.Tensor) -> torch.Tensor:

458

"""

459

Args:

460

waveform: Input tensor (..., time)

461

462

Returns:

463

Tensor: Pitch-shifted waveform

464

"""

465

```

466

467

### Data Augmentation Transforms

468

469

Transforms for data augmentation in machine learning training.

470

471

```python { .api }

472

class FrequencyMasking(torch.nn.Module):

473

"""Apply frequency masking to spectrograms."""

474

475

def __init__(self, freq_mask_param: int, iid_masks: bool = False) -> None:

476

"""

477

Args:

478

freq_mask_param: Maximum frequency mask length

479

iid_masks: Whether to apply independent masks to each example in batch

480

"""

481

482

def forward(self, specgram: torch.Tensor, mask_value: float = 0.0) -> torch.Tensor:

483

"""

484

Args:

485

specgram: Input spectrogram (..., freq, time)

486

mask_value: Value to use for masked regions

487

488

Returns:

489

Tensor: Masked spectrogram

490

"""

491

492

class TimeMasking(torch.nn.Module):

493

"""Apply time masking to spectrograms."""

494

495

def __init__(self, time_mask_param: int, iid_masks: bool = False, p: float = 1.0) -> None:

496

"""

497

Args:

498

time_mask_param: Maximum time mask length

499

iid_masks: Whether to apply independent masks

500

p: Maximum proportion of time steps that can be masked (must be within [0.0, 1.0])

501

"""

502

503

def forward(self, specgram: torch.Tensor, mask_value: float = 0.0) -> torch.Tensor:

504

"""

505

Args:

506

specgram: Input spectrogram (..., freq, time)

507

mask_value: Value to use for masked regions

508

509

Returns:

510

Tensor: Masked spectrogram

511

"""

512

513

class SpecAugment(torch.nn.Module):

514

"""Apply SpecAugment data augmentation."""

515

516

def __init__(self, n_time_masks: int = 1, time_mask_param: int = 80,

517

n_freq_masks: int = 1, freq_mask_param: int = 80,

518

iid_masks: bool = False) -> None:

519

"""

520

Args:

521

n_time_masks: Number of time masks

522

time_mask_param: Maximum time mask length

523

n_freq_masks: Number of frequency masks

524

freq_mask_param: Maximum frequency mask length

525

iid_masks: Whether to apply independent masks

526

"""

527

528

def forward(self, specgram: torch.Tensor, mask_value: float = 0.0) -> torch.Tensor:

529

"""

530

Args:

531

specgram: Input spectrogram (..., freq, time)

532

mask_value: Value to use for masked regions

533

534

Returns:

535

Tensor: Augmented spectrogram

536

"""

537

538

class AddNoise(torch.nn.Module):

539

"""Add noise to waveform."""

540

541

def __init__(self, noise: torch.Tensor, snr: torch.Tensor,

542

lengths: Optional[torch.Tensor] = None) -> None:

543

"""

544

Args:

545

noise: Noise tensor to add

546

snr: Signal-to-noise ratio in dB

547

lengths: Length of each sequence in batch

548

"""

549

550

def forward(self, waveform: torch.Tensor) -> torch.Tensor:

551

"""

552

Args:

553

waveform: Input tensor (..., time)

554

555

Returns:

556

Tensor: Waveform with added noise

557

"""

558

559

class SpeedPerturbation(torch.nn.Module):

560

"""Apply speed perturbation augmentation by randomly sampling from given factors."""

561

562

def __init__(self, orig_freq: int, factors: Sequence[float]) -> None:

563

"""

564

Args:

565

orig_freq: Original frequency of the signals

566

factors: Factors by which to adjust speed. Values >1.0 compress time, <1.0 stretch time

567

"""

568

569

def forward(self, waveform: torch.Tensor, lengths: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:

570

"""

571

Args:

572

waveform: Input signals (..., time)

573

lengths: Valid lengths of signals (...). Default: None

574

575

Returns:

576

Tuple[Tensor, Optional[Tensor]]: Speed-adjusted waveform and updated lengths

577

"""

578

```

579

580

### Audio Processing Transforms

581

582

Basic audio processing transforms for volume, fading, and emphasis.

583

584

```python { .api }

585

class Fade(torch.nn.Module):

586

"""Add a fade in and/or fade out to a waveform."""

587

588

def __init__(self, fade_in_len: int = 0, fade_out_len: int = 0, fade_shape: str = "linear") -> None:

589

"""

590

Args:

591

fade_in_len: Length of fade-in (time frames). Default: 0

592

fade_out_len: Length of fade-out (time frames). Default: 0

593

fade_shape: Shape of fade. Must be one of: "quarter_sine", "half_sine",

594

"linear", "logarithmic", "exponential". Default: "linear"

595

"""

596

597

def forward(self, waveform: torch.Tensor) -> torch.Tensor:

598

"""

599

Args:

600

waveform: Input tensor (..., time)

601

602

Returns:

603

Tensor: Faded waveform with same shape

604

"""

605

606

class Vol(torch.nn.Module):

607

"""Adjust volume of waveform."""

608

609

def __init__(self, gain: float, gain_type: str = "amplitude") -> None:

610

"""

611

Args:

612

gain: Interpreted according to gain_type:

613

- amplitude: positive amplitude ratio

614

- power: power (voltage squared)

615

- db: gain in decibels

616

gain_type: Type of gain. One of: "amplitude", "power", "db". Default: "amplitude"

617

"""

618

619

def forward(self, waveform: torch.Tensor) -> torch.Tensor:

620

"""

621

Args:

622

waveform: Input tensor (..., time)

623

624

Returns:

625

Tensor: Volume-adjusted waveform with same shape

626

"""

627

628

class Preemphasis(torch.nn.Module):

629

"""Pre-emphasizes a waveform along its last dimension."""

630

631

def __init__(self, coeff: float = 0.97) -> None:

632

"""

633

Args:

634

coeff: Pre-emphasis coefficient. Typically between 0.0 and 1.0. Default: 0.97

635

"""

636

637

def forward(self, waveform: torch.Tensor) -> torch.Tensor:

638

"""

639

Args:

640

waveform: Input tensor (..., time)

641

642

Returns:

643

Tensor: Pre-emphasized waveform with same shape

644

"""

645

646

class Deemphasis(torch.nn.Module):

647

"""De-emphasizes a waveform along its last dimension."""

648

649

def __init__(self, coeff: float = 0.97) -> None:

650

"""

651

Args:

652

coeff: De-emphasis coefficient. Typically between 0.0 and 1.0. Default: 0.97

653

"""

654

655

def forward(self, waveform: torch.Tensor) -> torch.Tensor:

656

"""

657

Args:

658

waveform: Input tensor (..., time)

659

660

Returns:

661

Tensor: De-emphasized waveform with same shape

662

"""

663

```

664

665

### Convolution Transforms

666

667

Convolution-based transforms for audio processing.

668

669

```python { .api }

670

class Convolve(torch.nn.Module):

671

"""Convolves inputs along their last dimension using the direct method."""

672

673

def __init__(self, mode: str = "full") -> None:

674

"""

675

Args:

676

mode: Must be one of ("full", "valid", "same").

677

- "full": Returns full convolution result (..., N + M - 1)

678

- "valid": Returns overlap segment (..., max(N, M) - min(N, M) + 1)

679

- "same": Returns center segment (..., N)

680

Default: "full"

681

"""

682

683

def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:

684

"""

685

Args:

686

x: First convolution operand (..., N)

687

y: Second convolution operand (..., M)

688

689

Returns:

690

Tensor: Convolution result with shape dictated by mode

691

"""

692

693

class FFTConvolve(torch.nn.Module):

694

"""Convolves inputs along their last dimension using FFT. Much faster than Convolve for large inputs."""

695

696

def __init__(self, mode: str = "full") -> None:

697

"""

698

Args:

699

mode: Must be one of ("full", "valid", "same"). Same as Convolve. Default: "full"

700

"""

701

702

def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:

703

"""

704

Args:

705

x: First convolution operand (..., N)

706

y: Second convolution operand (..., M)

707

708

Returns:

709

Tensor: FFT convolution result (always float tensors)

710

"""

711

```

712

713

### Multi-Channel Beamforming Transforms

714

715

Advanced multi-channel transforms for beamforming and array processing.

716

717

```python { .api }

718

class PSD(torch.nn.Module):

719

"""Compute cross-channel power spectral density (PSD) matrix."""

720

721

def __init__(self, multi_mask: bool = False, normalize: bool = True, eps: float = 1e-15) -> None:

722

"""

723

Args:

724

multi_mask: If True, only accepts multi-channel Time-Frequency masks. Default: False

725

normalize: If True, normalize the mask along the time dimension. Default: True

726

eps: Value to add to denominator in mask normalization. Default: 1e-15

727

"""

728

729

def forward(self, specgram: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:

730

"""

731

Args:

732

specgram: Multi-channel complex-valued spectrum (..., channel, freq, time)

733

mask: Time-Frequency mask for normalization (..., freq, time) or (..., channel, freq, time)

734

735

Returns:

736

Tensor: Complex-valued PSD matrix (..., freq, channel, channel)

737

"""

738

739

class MVDR(torch.nn.Module):

740

"""Minimum Variance Distortionless Response (MVDR) beamforming with Time-Frequency masks."""

741

742

def __init__(self, ref_channel: int = 0, solution: str = "ref_channel",

743

multi_mask: bool = False, diag_loading: bool = True,

744

diag_eps: float = 1e-7, online: bool = False) -> None:

745

"""

746

Args:

747

ref_channel: Reference channel for beamforming. Default: 0

748

solution: Solution method. One of ["ref_channel", "stv_evd", "stv_power"]. Default: "ref_channel"

749

multi_mask: If True, accepts multi-channel masks. Default: False

750

diag_loading: If True, applies diagonal loading to noise covariance. Default: True

751

diag_eps: Diagonal loading coefficient. Default: 1e-7

752

online: If True, updates weights based on previous covariance matrices. Default: False

753

"""

754

755

def forward(self, specgram: torch.Tensor, mask_s: torch.Tensor, mask_n: torch.Tensor) -> torch.Tensor:

756

"""

757

Args:

758

specgram: Multi-channel noisy spectrum (..., channel, freq, time)

759

mask_s: Time-Frequency mask for target speech

760

mask_n: Time-Frequency mask for noise

761

762

Returns:

763

Tensor: Enhanced single-channel spectrum (..., freq, time)

764

"""

765

766

class SoudenMVDR(torch.nn.Module):

767

"""MVDR beamforming using Souden's method."""

768

769

def __init__(self, ref_channel: int = 0, multi_mask: bool = False,

770

diag_loading: bool = True, diag_eps: float = 1e-7) -> None:

771

"""

772

Args:

773

ref_channel: Reference channel for beamforming. Default: 0

774

multi_mask: If True, accepts multi-channel masks. Default: False

775

diag_loading: If True, applies diagonal loading. Default: True

776

diag_eps: Diagonal loading coefficient. Default: 1e-7

777

"""

778

779

def forward(self, specgram: torch.Tensor, mask_s: torch.Tensor, mask_n: torch.Tensor) -> torch.Tensor:

780

"""

781

Args:

782

specgram: Multi-channel noisy spectrum (..., channel, freq, time)

783

mask_s: Time-Frequency mask for target speech

784

mask_n: Time-Frequency mask for noise

785

786

Returns:

787

Tensor: Enhanced single-channel spectrum using Souden method

788

"""

789

790

class RTFMVDR(torch.nn.Module):

791

"""MVDR beamforming using Relative Transfer Function (RTF)."""

792

793

def __init__(self, ref_channel: int = 0, multi_mask: bool = False,

794

diag_loading: bool = True, diag_eps: float = 1e-7) -> None:

795

"""

796

Args:

797

ref_channel: Reference channel for beamforming. Default: 0

798

multi_mask: If True, accepts multi-channel masks. Default: False

799

diag_loading: If True, applies diagonal loading. Default: True

800

diag_eps: Diagonal loading coefficient. Default: 1e-7

801

"""

802

803

def forward(self, specgram: torch.Tensor, mask_s: torch.Tensor, mask_n: torch.Tensor) -> torch.Tensor:

804

"""

805

Args:

806

specgram: Multi-channel noisy spectrum (..., channel, freq, time)

807

mask_s: Time-Frequency mask for target speech

808

mask_n: Time-Frequency mask for noise

809

810

Returns:

811

Tensor: Enhanced single-channel spectrum using RTF method

812

"""

813

```

814

815

### Advanced Processing Transforms

816

817

Specialized transforms for feature processing and analysis.

818

819

```python { .api }

820

class SlidingWindowCmn(torch.nn.Module):

821

"""Apply sliding-window cepstral mean (and optionally variance) normalization per utterance."""

822

823

def __init__(self, cmn_window: int = 600, min_cmn_window: int = 100,

824

center: bool = False, norm_vars: bool = False) -> None:

825

"""

826

Args:

827

cmn_window: Window in frames for running average CMN computation. Default: 600

828

min_cmn_window: Minimum CMN window used at start of decoding. Default: 100

829

center: If True, use centered window; if False, window is to the left. Default: False

830

norm_vars: If True, normalize variance to one. Default: False

831

"""

832

833

def forward(self, specgram: torch.Tensor) -> torch.Tensor:

834

"""

835

Args:

836

specgram: Spectrogram (..., time, freq)

837

838

Returns:

839

Tensor: CMN normalized spectrogram with same shape

840

"""

841

842

class Vad(torch.nn.Module):

843

"""Voice Activity Detector. Similar to SoX implementation."""

844

845

def __init__(self, sample_rate: int, trigger_level: float = 7.0, trigger_time: float = 0.25,

846

search_time: float = 1.0, allowed_gap: float = 0.25, pre_trigger_time: float = 0.0,

847

boot_time: float = 0.35, noise_up_time: float = 0.1, noise_down_time: float = 0.01,

848

noise_reduction_amount: float = 1.35, measure_freq: float = 20.0,

849

measure_duration: Optional[float] = None, measure_smooth_time: float = 0.4,

850

hp_filter_freq: float = 50.0, lp_filter_freq: float = 6000.0,

851

hp_lifter_freq: float = 150.0, lp_lifter_freq: float = 2000.0) -> None:

852

"""

853

Args:

854

sample_rate: Sample rate of audio signal

855

trigger_level: Measurement level used to trigger activity detection. Default: 7.0

856

trigger_time: Time constant to help ignore short bursts. Default: 0.25

857

search_time: Amount of audio to search for quieter bursts. Default: 1.0

858

allowed_gap: Allowed gap between quieter bursts. Default: 0.25

859

pre_trigger_time: Amount of audio to preserve before trigger. Default: 0.0

860

boot_time: Time for initial noise estimate. Default: 0.35

861

noise_up_time: Time constant for increasing noise level. Default: 0.1

862

noise_down_time: Time constant for decreasing noise level. Default: 0.01

863

noise_reduction_amount: Amount of noise reduction. Default: 1.35

864

measure_freq: Frequency of algorithm processing. Default: 20.0

865

measure_duration: Measurement duration. Default: None (twice measurement period)

866

measure_smooth_time: Time constant for spectral smoothing. Default: 0.4

867

hp_filter_freq: High-pass filter frequency. Default: 50.0

868

lp_filter_freq: Low-pass filter frequency. Default: 6000.0

869

hp_lifter_freq: High-pass lifter frequency. Default: 150.0

870

lp_lifter_freq: Low-pass lifter frequency. Default: 2000.0

871

"""

872

873

def forward(self, waveform: torch.Tensor) -> torch.Tensor:

874

"""

875

Args:

876

waveform: Input tensor (..., time)

877

878

Returns:

879

Tensor: Voice activity detection result

880

"""

881

```

882

883

### Loss Functions

884

885

Loss functions for training neural networks with audio data.

886

887

```python { .api }

888

class RNNTLoss(torch.nn.Module):

889

"""Compute the RNN Transducer loss from Sequence Transduction with Recurrent Neural Networks."""

890

891

def __init__(self, blank: int = -1, clamp: float = -1.0, reduction: str = "mean",

892

fused_log_softmax: bool = True) -> None:

893

"""

894

Args:

895

blank: Blank label index; negative values count from the end of the vocabulary. Default: -1

896

clamp: Clamp for gradients. Default: -1

897

reduction: Specifies reduction to apply: "none", "mean", or "sum". Default: "mean"

898

fused_log_softmax: Set to False if calling log_softmax outside of loss. Default: True

899

"""

900

901

def forward(self, logits: torch.Tensor, targets: torch.Tensor, logit_lengths: torch.Tensor,

902

target_lengths: torch.Tensor) -> torch.Tensor:

903

"""

904

Args:

905

logits: Tensor with shape (N, T, U, V) where N=batch, T=time, U=target, V=vocab

906

targets: Tensor with shape (N, S) where S=target sequence length

907

logit_lengths: Tensor with shape (N,) representing lengths of logits

908

target_lengths: Tensor with shape (N,) representing lengths of targets

909

910

Returns:

911

Tensor: RNN Transducer loss

912

"""

913

```

914

915

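Usage example combining multiple transforms (the pipeline hard-codes a 44.1 kHz input rate; construct `T.Resample(orig_freq=orig_sr, new_freq=16000)` if your files use a different rate):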

916

917

```python

918

import torch

919

import torchaudio

920

from torchaudio import transforms as T

921

922

# Create a processing pipeline

923

transform_pipeline = torch.nn.Sequential(

924

T.Resample(orig_freq=44100, new_freq=16000), # Resample to 16kHz

925

T.MelSpectrogram(

926

sample_rate=16000,

927

n_fft=1024,

928

hop_length=256,

929

n_mels=80

930

), # Convert to mel spectrogram

931

T.AmplitudeToDB(stype="power"), # Convert to dB scale

932

T.FrequencyMasking(freq_mask_param=15), # Apply frequency masking

933

T.TimeMasking(time_mask_param=35) # Apply time masking

934

)

935

936

# Load and process audio

937

waveform, orig_sr = torchaudio.load("audio.wav")

938

processed = transform_pipeline(waveform)

939

```

940

941

These transforms provide the building blocks for creating sophisticated audio processing pipelines that integrate seamlessly with PyTorch's neural network ecosystem.