# Audio Transforms

PyTorch-compatible transform classes for building audio processing pipelines. These transforms are `torch.nn.Module` subclasses, so they can be composed with neural networks and moved between devices like any other module. Where the underlying operation is differentiable, they can also be trained end-to-end with automatic differentiation. Quantizing or trimming transforms such as mu-law encoding or voice activity detection are not differentiable.

## Capabilities

### Spectral Transforms

Core spectral analysis transforms for converting between the time and frequency domains.

```python { .api }
class Spectrogram(torch.nn.Module):
    """Compute the spectrogram of an audio signal via the short-time Fourier transform."""

    def __init__(self, n_fft: int = 400, win_length: Optional[int] = None,
                 hop_length: Optional[int] = None, pad: int = 0,
                 window_fn: Callable[..., torch.Tensor] = torch.hann_window,
                 power: Optional[float] = 2.0, normalized: bool = False,
                 wkwargs: Optional[Dict[str, Any]] = None, center: bool = True,
                 pad_mode: str = "reflect", onesided: bool = True) -> None:
        """
        Args:
            n_fft: Size of FFT; yields n_fft // 2 + 1 frequency bins when onesided
            win_length: Window size (defaults to n_fft)
            hop_length: Length of hop between STFT windows (defaults to win_length // 2)
            pad: Two-sided padding of signal
            window_fn: Function that creates the window tensor
                (e.g., torch.hann_window, torch.hamming_window)
            power: Exponent for the magnitude spectrogram (must be > 0):
                1.0 for magnitude, 2.0 for power. None returns the complex spectrum instead.
            normalized: Whether to normalize the STFT output by the window's L2 norm
            wkwargs: Additional keyword arguments for window_fn
            center: Whether to pad the waveform on both sides so that the t-th frame
                is centered at time t * hop_length
            pad_mode: Padding mode used when center is True
            onesided: Whether to return only the non-redundant half of the spectrum
        """

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Args:
            waveform: Input tensor (..., time)

        Returns:
            Tensor: Spectrogram (..., freq, time), where freq is n_fft // 2 + 1 when
            onesided and time is the number of STFT frames. Complex-valued when
            power is None.
        """

class InverseSpectrogram(torch.nn.Module):
    """Reconstruct a waveform from a complex spectrogram using the inverse STFT."""

    def __init__(self, n_fft: int = 400, win_length: Optional[int] = None,
                 hop_length: Optional[int] = None, pad: int = 0,
                 window_fn: Callable[..., torch.Tensor] = torch.hann_window,
                 normalized: bool = False, wkwargs: Optional[Dict[str, Any]] = None,
                 center: bool = True, pad_mode: str = "reflect",
                 onesided: bool = True, length: Optional[int] = None) -> None:
        """
        Args:
            length: Expected length of the reconstructed signal.
                NOTE(review): upstream torchaudio accepts `length` as an argument of
                forward() rather than the constructor — confirm against the installed
                version.
            (other parameters same as Spectrogram; they should match the settings
            used to compute the spectrogram being inverted)
        """

    def forward(self, spectrogram: torch.Tensor, length: Optional[int] = None) -> torch.Tensor:
        """
        Args:
            spectrogram: Complex-valued input spectrogram (..., freq, time)
            length: Expected length of the reconstructed signal (optional)

        Returns:
            Tensor: Reconstructed waveform (..., time)
        """

class GriffinLim(torch.nn.Module):
    """Reconstruct a waveform from a magnitude spectrogram using the Griffin-Lim algorithm."""

    def __init__(self, n_fft: int = 400, n_iter: int = 32, win_length: Optional[int] = None,
                 hop_length: Optional[int] = None, window_fn: Callable[..., torch.Tensor] = torch.hann_window,
                 power: float = 2.0, wkwargs: Optional[Dict[str, Any]] = None,
                 momentum: float = 0.99, length: Optional[int] = None,
                 rand_init: bool = True) -> None:
        """
        Args:
            n_iter: Number of Griffin-Lim iterations
            power: Exponent that was applied to the magnitude when the input
                spectrogram was computed (1.0 for magnitude, 2.0 for power)
            wkwargs: Additional keyword arguments for window_fn
            momentum: Momentum parameter for fast Griffin-Lim, in the range [0, 1).
                0 recovers the original Griffin-Lim algorithm; values near 1 can
                converge faster.
            length: Expected length of the output waveform (optional)
            rand_init: Whether to initialize the phase randomly (True) or with zeros (False)
            (other parameters same as Spectrogram; they should match the settings
            used to compute the input spectrogram)
        """

    def forward(self, specgram: torch.Tensor) -> torch.Tensor:
        """
        Args:
            specgram: Magnitude-only spectrogram (..., freq, time), where freq is
                n_fft // 2 + 1

        Returns:
            Tensor: Reconstructed waveform (..., time)
        """
```

### Mel-Scale Transforms

Transforms for mel-scale processing, commonly used in speech and music analysis.

```python { .api }
class MelSpectrogram(torch.nn.Module):
    """Spectrogram whose frequency axis has been projected onto mel filter banks."""

    def __init__(
        self,
        sample_rate: int = 16000,
        n_fft: int = 400,
        win_length: Optional[int] = None,
        hop_length: Optional[int] = None,
        f_min: float = 0.0,
        f_max: Optional[float] = None,
        n_mels: int = 128,
        window_fn: Callable[..., torch.Tensor] = torch.hann_window,
        power: float = 2.0,
        normalized: bool = False,
        wkwargs: Optional[Dict[str, Any]] = None,
        center: bool = True,
        pad_mode: str = "reflect",
        onesided: bool = True,
        norm: Optional[str] = None,
        mel_scale: str = "htk",
    ) -> None:
        """
        Args:
            sample_rate: Sample rate of the input audio
            f_min: Lowest frequency covered by the filter bank
            f_max: Highest frequency covered by the filter bank;
                None means sample_rate // 2
            n_mels: How many mel filter banks to produce
            norm: Filter-bank normalization, either "slaney" or None
            mel_scale: Mel scale formula, either "htk" or "slaney"
            (remaining parameters behave as in Spectrogram)
        """

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Args:
            waveform: Audio tensor shaped (..., time)

        Returns:
            Tensor: Mel-scale spectrogram shaped (..., n_mels, time)
        """

class MelScale(torch.nn.Module):
    """Project a linear-frequency spectrogram onto the mel scale."""

    def __init__(
        self,
        n_mels: int = 128,
        sample_rate: int = 16000,
        f_min: float = 0.0,
        f_max: Optional[float] = None,
        n_stft: Optional[int] = None,
        norm: Optional[str] = None,
        mel_scale: str = "htk",
    ) -> None:
        """
        Args:
            n_mels: How many mel filter banks to produce
            sample_rate: Sample rate of the original audio
            f_min: Lowest frequency covered by the filter bank
            f_max: Highest frequency covered by the filter bank
            n_stft: Count of STFT frequency bins in the input, usually n_fft // 2 + 1
            norm: Filter-bank normalization ("slaney" or None)
            mel_scale: Mel scale formula ("htk" or "slaney")
        """

    def forward(self, specgram: torch.Tensor) -> torch.Tensor:
        """
        Args:
            specgram: Linear-frequency spectrogram shaped (..., freq, time)

        Returns:
            Tensor: Mel spectrogram shaped (..., n_mels, time)
        """

class InverseMelScale(torch.nn.Module):
    """Estimate a linear-frequency spectrogram from a mel-scale spectrogram.

    NOTE(review): this signature (max_iter, tolerance_loss, tolerance_change,
    sgdargs) describes the older SGD-based solver. Newer torchaudio releases
    appear to solve this with a least-squares routine and drop these optimizer
    arguments — confirm against the installed version.
    """

    def __init__(self, n_stft: int, n_mels: int = 128, sample_rate: int = 16000,
                 f_min: float = 0.0, f_max: Optional[float] = None,
                 max_iter: int = 100000, tolerance_loss: float = 1e-5,
                 tolerance_change: float = 1e-8, sgdargs: Optional[Dict[str, Any]] = None,
                 norm: Optional[str] = None, mel_scale: str = "htk") -> None:
        """
        Args:
            n_stft: Number of STFT frequency bins of the output (typically n_fft // 2 + 1)
            max_iter: Maximum number of optimization iterations
            tolerance_loss: Stop when the loss falls below this value
            tolerance_change: Stop when the loss changes by less than this value
            sgdargs: Keyword arguments for the SGD optimizer
            (other parameters same as MelScale; they should match the settings used
            to produce the mel spectrogram)
        """

    def forward(self, melspec: torch.Tensor) -> torch.Tensor:
        """
        Args:
            melspec: Mel-scale spectrogram (..., n_mels, time)

        Returns:
            Tensor: Linear spectrogram (..., n_stft, time)
        """
```

### Feature Extraction Transforms

Transforms for extracting common audio features, such as cepstral coefficients, deltas, spectral centroid, and loudness.

```python { .api }
class MFCC(torch.nn.Module):
    """Compute Mel-frequency cepstral coefficients (MFCCs) from a waveform."""

    def __init__(self, sample_rate: int = 16000, n_mfcc: int = 40,
                 dct_type: int = 2, norm: Optional[str] = "ortho", log_mels: bool = False,
                 melkwargs: Optional[Dict[str, Any]] = None) -> None:
        """
        Args:
            sample_rate: Sample rate of audio
            n_mfcc: Number of MFCC coefficients to keep
            dct_type: DCT type; only type 2 is supported
            norm: DCT normalization ("ortho" or None)
            log_mels: If True, apply log compression (log-mel) to the mel spectrogram
                before the DCT; if False, use a decibel-scaled mel spectrogram
            melkwargs: Additional keyword arguments forwarded to MelSpectrogram
        """

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Args:
            waveform: Input tensor (..., time)

        Returns:
            Tensor: MFCC coefficients (..., n_mfcc, time)
        """

class LFCC(torch.nn.Module):
    """Compute Linear-frequency cepstral coefficients (LFCCs) from a waveform.

    NOTE(review): upstream torchaudio appears to order these parameters as
    (sample_rate, n_filter, f_min, f_max, n_lfcc, dct_type, norm, log_lf,
    speckwargs). Pass arguments by keyword to be safe — confirm against the
    installed version.
    """

    def __init__(self, sample_rate: int = 16000, n_lfcc: int = 40,
                 speckwargs: Optional[Dict[str, Any]] = None, n_filter: int = 128,
                 f_min: float = 0.0, f_max: Optional[float] = None,
                 dct_type: int = 2, norm: Optional[str] = "ortho", log_lf: bool = False) -> None:
        """
        Args:
            sample_rate: Sample rate of audio
            n_lfcc: Number of LFCC coefficients to keep
            speckwargs: Additional keyword arguments forwarded to Spectrogram
            n_filter: Number of linear filter banks
            f_min: Minimum frequency
            f_max: Maximum frequency
            dct_type: DCT type; only type 2 is supported
            norm: DCT normalization ("ortho" or None)
            log_lf: If True, apply log compression to the linear filter-bank
                spectrogram before the DCT; if False, use a decibel-scaled one
        """

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Args:
            waveform: Input tensor (..., time)

        Returns:
            Tensor: LFCC coefficients (..., n_lfcc, time)
        """

class ComputeDeltas(torch.nn.Module):
    """Delta (first-derivative) features of an input feature sequence."""

    def __init__(
        self,
        win_length: int = 5,
        mode: str = "replicate",
    ) -> None:
        """
        Args:
            win_length: Length of the window each delta is computed over
            mode: Padding mode applied at the edges before the deltas are taken
        """

    def forward(self, specgram: torch.Tensor) -> torch.Tensor:
        """
        Args:
            specgram: Feature tensor shaped (..., freq, time)

        Returns:
            Tensor: Delta features, shaped exactly like the input
        """

class SpectralCentroid(torch.nn.Module):
    """Per-frame spectral centroid of a waveform."""

    def __init__(
        self,
        sample_rate: int,
        n_fft: int = 400,
        win_length: Optional[int] = None,
        hop_length: Optional[int] = None,
        pad: int = 0,
        window_fn: Callable[..., torch.Tensor] = torch.hann_window,
        wkwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Args:
            sample_rate: Sample rate of the input audio
            (the STFT parameters have the same meaning as in Spectrogram)
        """

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Args:
            waveform: Audio tensor shaped (..., time)

        Returns:
            Tensor: One centroid per frame, shaped (..., time)
        """

class Loudness(torch.nn.Module):
    """Measure loudness following the ITU-R BS.1770-4 standard."""

    def __init__(self, sample_rate: int) -> None:
        """
        Args:
            sample_rate: Sample rate of audio
        """

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Args:
            waveform: Input tensor (..., channels, time). BS.1770 weights and sums
                per-channel energy, so an explicit channel axis is required
                (up to 5 channels).

        Returns:
            Tensor: Loudness in LUFS (equivalently LKFS), one value per item, shape (...)
        """
```

### Amplitude and Encoding Transforms

Transforms for converting amplitudes to the decibel scale and for mu-law encoding and decoding.

```python { .api }
class AmplitudeToDB(torch.nn.Module):
    """Map a power or magnitude spectrogram onto the decibel scale."""

    def __init__(
        self,
        stype: str = "power",
        top_db: Optional[float] = None,
    ) -> None:
        """
        Args:
            stype: Scale of the input, "power" (squared magnitude) or "magnitude"
            top_db: Dynamic-range floor in decibels. Values more than top_db below
                the maximum are clipped. None disables clipping.
        """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Spectrogram shaped (..., freq, time)

        Returns:
            Tensor: The same spectrogram expressed in decibels
        """

class MuLawEncoding(torch.nn.Module):
    """Mu-law compand and quantize a waveform."""

    def __init__(
        self,
        quantization_channels: int = 256,
    ) -> None:
        """
        Args:
            quantization_channels: How many discrete levels the output uses
        """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Waveform shaped (..., time), expected in the range [-1, 1]

        Returns:
            Tensor: Encoded signal with values in [0, quantization_channels - 1]
        """

class MuLawDecoding(torch.nn.Module):
    """Invert mu-law encoding back to a waveform."""

    def __init__(
        self,
        quantization_channels: int = 256,
    ) -> None:
        """
        Args:
            quantization_channels: How many discrete levels the encoded input uses
        """

    def forward(self, x_mu: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x_mu: Encoded signal shaped (..., time), with values in
                [0, quantization_channels - 1]

        Returns:
            Tensor: Decoded waveform in the range [-1, 1]
        """
```

### Resampling and Time Manipulation

Transforms for changing the sample rate, playback speed, duration, and pitch of audio.

```python { .api }
class Resample(torch.nn.Module):
    """Resample a waveform to a different sample rate using band-limited sinc interpolation."""

    def __init__(self, orig_freq: int = 16000, new_freq: int = 16000,
                 resampling_method: str = "sinc_interp_kaiser",
                 lowpass_filter_width: int = 6, rolloff: float = 0.99,
                 beta: Optional[float] = None, dtype: torch.dtype = torch.float32) -> None:
        """
        Args:
            orig_freq: Original sample rate
            new_freq: Target sample rate
            resampling_method: "sinc_interp_hann" or "sinc_interp_kaiser".
                NOTE(review): upstream torchaudio defaults to "sinc_interp_hann";
                confirm the default against the installed version.
            lowpass_filter_width: Controls the sharpness of the anti-aliasing filter;
                larger is sharper but slower
            rolloff: Roll-off frequency of the filter as a fraction of the Nyquist
                frequency; lower values reduce aliasing but attenuate high frequencies
            beta: Shape parameter of the Kaiser window (used only with
                "sinc_interp_kaiser")
            dtype: Precision in which the resampling kernel is precomputed and cached.
                This is not the output dtype.
        """

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Args:
            waveform: Input tensor (..., time)

        Returns:
            Tensor: Resampled waveform (..., new_time)
        """

class Speed(torch.nn.Module):
    """Adjust playback speed by resampling."""

    def __init__(self, orig_freq: int, factor: float) -> None:
        """
        Args:
            orig_freq: Original sample rate
            factor: Speed factor (>1.0 = faster / shorter, <1.0 = slower / longer)
        """

    def forward(self, waveform: torch.Tensor,
                lengths: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            waveform: Input tensor (..., time)
            lengths: Valid length of each sequence in the batch (...). Default: None

        Returns:
            Tuple[Tensor, Optional[Tensor]]: Speed-adjusted waveform (..., new_time)
            and the valid lengths after adjustment (None if lengths was None)
        """

class TimeStretch(torch.nn.Module):
    """Stretch the time axis of a complex spectrogram without changing pitch (phase vocoder)."""

    def __init__(self, hop_length: Optional[int] = None, n_freq: int = 201,
                 fixed_rate: Optional[float] = None) -> None:
        """
        Args:
            hop_length: Hop length used by the phase vocoder
                (defaults to n_fft // 2, where n_fft = (n_freq - 1) * 2)
            n_freq: Number of frequency bins of the input spectrogram
            fixed_rate: Fixed speed-up rate applied by default (None for variable rate)
        """

    def forward(self, complex_specgrams: torch.Tensor, rate: float = 1.0) -> torch.Tensor:
        """
        Args:
            complex_specgrams: Complex spectrogram (..., freq, time)
            rate: Speed-up factor. >1.0 speeds up (fewer frames),
                <1.0 slows down (more frames).
                NOTE(review): upstream torchaudio names this argument
                `overriding_rate` (default None, falling back to fixed_rate) —
                confirm against the installed version.

        Returns:
            Tensor: Time-stretched complex spectrogram (..., freq, ceil(time / rate))
        """

class PitchShift(torch.nn.Module):
    """Shift the pitch of a waveform without changing its duration.

    NOTE(review): the upstream transform appears to take `window_fn` and
    `wkwargs` instead of a precomputed `window` tensor (the tensor form belongs
    to torchaudio.functional.pitch_shift) — confirm against the installed version.
    """

    def __init__(self, sample_rate: int, n_steps: float, bins_per_octave: int = 12,
                 n_fft: int = 512, win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 window: Optional[torch.Tensor] = None) -> None:
        """
        Args:
            sample_rate: Sample rate of audio
            n_steps: Number of steps to shift; positive values raise the pitch,
                negative values lower it (semitones when bins_per_octave is 12)
            bins_per_octave: Number of steps per octave
            n_fft: FFT size
            win_length: Window length
            hop_length: Hop length
            window: Window tensor
        """

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Args:
            waveform: Input tensor (..., time)

        Returns:
            Tensor: Pitch-shifted waveform with the same shape as the input
        """
```

### Data Augmentation Transforms

Transforms for augmenting training data in machine learning pipelines.

```python { .api }
class FrequencyMasking(torch.nn.Module):
    """Mask a random band of frequency bins in a spectrogram."""

    def __init__(
        self,
        freq_mask_param: int,
        iid_masks: bool = False,
    ) -> None:
        """
        Args:
            freq_mask_param: Upper bound on the width of the mask, in bins
            iid_masks: Draw an independent mask for every example in the batch
        """

    def forward(self, specgram: torch.Tensor, mask_value: float = 0.0) -> torch.Tensor:
        """
        Args:
            specgram: Spectrogram shaped (..., freq, time)
            mask_value: Fill value written into the masked band

        Returns:
            Tensor: Spectrogram with the masked band filled in
        """

class TimeMasking(torch.nn.Module):
    """Apply time masking to spectrograms."""

    def __init__(self, time_mask_param: int, iid_masks: bool = False, p: float = 1.0) -> None:
        """
        Args:
            time_mask_param: Maximum time mask length
            iid_masks: Whether to apply independent masks to each example in the batch
            p: Maximum proportion of time steps that can be masked, in [0.0, 1.0]
                (caps the mask length relative to the number of frames)
        """

    def forward(self, specgram: torch.Tensor, mask_value: float = 0.0) -> torch.Tensor:
        """
        Args:
            specgram: Input spectrogram (..., freq, time)
            mask_value: Value to use for masked regions

        Returns:
            Tensor: Masked spectrogram
        """

class SpecAugment(torch.nn.Module):
    """Apply SpecAugment: repeated time and frequency masking.

    NOTE(review): upstream torchaudio's SpecAugment appears to require
    n_time_masks/time_mask_param/n_freq_masks/freq_mask_param without defaults.
    It also defaults iid_masks to True, adds `p` and `zero_masking`, and its
    forward() takes no mask_value. Confirm against the installed version.
    """

    def __init__(self, n_time_masks: int = 1, time_mask_param: int = 80,
                 n_freq_masks: int = 1, freq_mask_param: int = 80,
                 iid_masks: bool = False) -> None:
        """
        Args:
            n_time_masks: Number of time masks to apply
            time_mask_param: Maximum time mask length
            n_freq_masks: Number of frequency masks to apply
            freq_mask_param: Maximum frequency mask length
            iid_masks: Whether to apply independent masks to each example in the batch
        """

    def forward(self, specgram: torch.Tensor, mask_value: float = 0.0) -> torch.Tensor:
        """
        Args:
            specgram: Input spectrogram (..., freq, time)
            mask_value: Value to use for masked regions

        Returns:
            Tensor: Augmented spectrogram
        """

class AddNoise(torch.nn.Module):
    """Add noise to a waveform at a given signal-to-noise ratio.

    NOTE(review): upstream torchaudio's AddNoise appears to take no constructor
    arguments and to receive noise, snr and lengths in forward() instead —
    confirm against the installed version.
    """

    def __init__(self, noise: torch.Tensor, snr: torch.Tensor,
                 lengths: Optional[torch.Tensor] = None) -> None:
        """
        Args:
            noise: Noise to add, with the same shape as the waveform (..., time)
            snr: Target signal-to-noise ratio in dB, shape (...)
            lengths: Valid length of each signal/noise pair, shape (...);
                None means the full length is valid
        """

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Args:
            waveform: Input tensor (..., time)

        Returns:
            Tensor: Waveform with added noise, same shape as the input
        """

class SpeedPerturbation(torch.nn.Module):
    """Speed-perturbation augmentation: resample by a factor drawn at random from a fixed set."""

    def __init__(
        self,
        orig_freq: int,
        factors: Sequence[float],
    ) -> None:
        """
        Args:
            orig_freq: Sample rate of the incoming signals
            factors: Candidate speed factors. Factors above 1.0 shorten (compress)
                the signal in time; factors below 1.0 lengthen (stretch) it.
        """

    def forward(
        self,
        waveform: torch.Tensor,
        lengths: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            waveform: Signals shaped (..., time)
            lengths: Valid lengths of the signals, shaped (...). Default: None

        Returns:
            Tuple[Tensor, Optional[Tensor]]: The perturbed waveform together with
            its updated lengths
        """
```

### Audio Processing Transforms

Basic audio processing transforms for fading, volume adjustment, and pre-/de-emphasis.

```python { .api }
class Fade(torch.nn.Module):
    """Apply a fade-in, a fade-out, or both to a waveform."""

    def __init__(
        self,
        fade_in_len: int = 0,
        fade_out_len: int = 0,
        fade_shape: str = "linear",
    ) -> None:
        """
        Args:
            fade_in_len: Fade-in duration in samples (time frames). Default: 0
            fade_out_len: Fade-out duration in samples (time frames). Default: 0
            fade_shape: Curve of the fade; one of "quarter_sine", "half_sine",
                "linear", "logarithmic", "exponential". Default: "linear"
        """

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Args:
            waveform: Audio tensor shaped (..., time)

        Returns:
            Tensor: Faded waveform, shaped like the input
        """

class Vol(torch.nn.Module):
    """Scale the volume of a waveform."""

    def __init__(
        self,
        gain: float,
        gain_type: str = "amplitude",
    ) -> None:
        """
        Args:
            gain: Gain value whose meaning depends on gain_type:
                "amplitude" -> positive amplitude ratio;
                "power" -> power ratio (voltage squared);
                "db" -> gain in decibels
            gain_type: One of "amplitude", "power", "db". Default: "amplitude"
        """

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Args:
            waveform: Audio tensor shaped (..., time)

        Returns:
            Tensor: Volume-adjusted waveform, shaped like the input
        """

class Preemphasis(torch.nn.Module):
    """Pre-emphasis filter applied along the last dimension: y[i] = x[i] - coeff * x[i - 1]."""

    def __init__(
        self,
        coeff: float = 0.97,
    ) -> None:
        """
        Args:
            coeff: Emphasis coefficient, usually between 0.0 and 1.0. Default: 0.97
        """

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Args:
            waveform: Audio tensor shaped (..., time)

        Returns:
            Tensor: Pre-emphasized waveform, shaped like the input
        """

class Deemphasis(torch.nn.Module):
    """De-emphasis filter applied along the last dimension: y[i] = x[i] + coeff * y[i - 1]."""

    def __init__(
        self,
        coeff: float = 0.97,
    ) -> None:
        """
        Args:
            coeff: De-emphasis coefficient, usually between 0.0 and 1.0. Default: 0.97
        """

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Args:
            waveform: Audio tensor shaped (..., time)

        Returns:
            Tensor: De-emphasized waveform, shaped like the input
        """
```

### Convolution Transforms

Transforms that convolve two signals along their last dimension, either directly or via the FFT.

```python { .api }
class Convolve(torch.nn.Module):
    """Direct (time-domain) convolution of two inputs along their last dimension."""

    def __init__(
        self,
        mode: str = "full",
    ) -> None:
        """
        Args:
            mode: Output size policy, one of "full", "valid", "same":
                "full"  -> entire convolution, shape (..., N + M - 1);
                "valid" -> only where the inputs fully overlap,
                           shape (..., max(N, M) - min(N, M) + 1);
                "same"  -> centered segment, shape (..., N).
                Default: "full"
        """

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: First operand shaped (..., N)
            y: Second operand shaped (..., M)

        Returns:
            Tensor: Convolution of x and y, sized according to mode
        """

class FFTConvolve(torch.nn.Module):
    """FFT-based convolution along the last dimension; much faster than Convolve for long inputs."""

    def __init__(
        self,
        mode: str = "full",
    ) -> None:
        """
        Args:
            mode: Output size policy, one of "full", "valid", "same"
                (identical semantics to Convolve). Default: "full"
        """

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: First operand shaped (..., N)
            y: Second operand shaped (..., M)

        Returns:
            Tensor: Convolution of x and y, sized according to mode
            (always a floating-point tensor)
        """
```

### Multi-Channel Beamforming Transforms

Multi-channel transforms for power spectral density estimation and MVDR beamforming with microphone arrays.

```python { .api }
class PSD(torch.nn.Module):
    """Cross-channel power spectral density (PSD) matrix of a multi-channel spectrum."""

    def __init__(
        self,
        multi_mask: bool = False,
        normalize: bool = True,
        eps: float = 1e-15,
    ) -> None:
        """
        Args:
            multi_mask: When True, only multi-channel time-frequency masks are accepted.
                Default: False
            normalize: When True, the mask is normalized along the time axis.
                Default: True
            eps: Small constant added to the denominator during mask normalization.
                Default: 1e-15
        """

    def forward(self, specgram: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            specgram: Complex multi-channel spectrum shaped (..., channel, freq, time)
            mask: Optional time-frequency mask, shaped (..., freq, time), or
                (..., channel, freq, time) when multi_mask is True

        Returns:
            Tensor: Complex PSD matrix shaped (..., freq, channel, channel)
        """

class MVDR(torch.nn.Module):
    """Minimum Variance Distortionless Response (MVDR) beamforming with Time-Frequency masks."""

    def __init__(self, ref_channel: int = 0, solution: str = "ref_channel",
                 multi_mask: bool = False, diag_loading: bool = True,
                 diag_eps: float = 1e-7, online: bool = False) -> None:
        """
        Args:
            ref_channel: Reference channel for beamforming. Default: 0
            solution: Solution method. One of ["ref_channel", "stv_evd", "stv_power"].
                Default: "ref_channel"
            multi_mask: If True, accepts multi-channel masks. Default: False
            diag_loading: If True, applies diagonal loading to noise covariance. Default: True
            diag_eps: Diagonal loading coefficient. Default: 1e-7
            online: If True, updates weights based on previous covariance matrices.
                Default: False
        """

    def forward(self, specgram: torch.Tensor, mask_s: torch.Tensor,
                mask_n: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            specgram: Multi-channel complex noisy spectrum (..., channel, freq, time)
            mask_s: Time-Frequency mask for target speech, (..., freq, time), or
                (..., channel, freq, time) when multi_mask is True
            mask_n: Time-Frequency mask for noise, same shape as mask_s.
                If None, 1 - mask_s is used.

        Returns:
            Tensor: Enhanced single-channel complex spectrum (..., freq, time)
        """

class SoudenMVDR(torch.nn.Module):
    """MVDR beamforming using Souden's method.

    NOTE(review): upstream torchaudio's SoudenMVDR appears to take no
    constructor arguments. Its forward() instead consumes the speech and noise
    PSD matrices (psd_s, psd_n) and a reference channel, with diagonal-loading
    options, rather than masks. Confirm against the installed version.
    """

    def __init__(self, ref_channel: int = 0, multi_mask: bool = False,
                 diag_loading: bool = True, diag_eps: float = 1e-7) -> None:
        """
        Args:
            ref_channel: Reference channel for beamforming. Default: 0
            multi_mask: If True, accepts multi-channel masks. Default: False
            diag_loading: If True, applies diagonal loading. Default: True
            diag_eps: Diagonal loading coefficient. Default: 1e-7
        """

    def forward(self, specgram: torch.Tensor, mask_s: torch.Tensor, mask_n: torch.Tensor) -> torch.Tensor:
        """
        Args:
            specgram: Multi-channel complex noisy spectrum (..., channel, freq, time)
            mask_s: Time-Frequency mask for target speech
            mask_n: Time-Frequency mask for noise

        Returns:
            Tensor: Enhanced single-channel complex spectrum (..., freq, time)
        """

class RTFMVDR(torch.nn.Module):
    """MVDR beamforming using a Relative Transfer Function (RTF) steering vector.

    NOTE(review): upstream torchaudio's RTFMVDR appears to take no constructor
    arguments. Its forward() instead consumes an RTF vector, the noise PSD
    matrix (psd_n) and a reference channel, with diagonal-loading options,
    rather than masks. Confirm against the installed version.
    """

    def __init__(self, ref_channel: int = 0, multi_mask: bool = False,
                 diag_loading: bool = True, diag_eps: float = 1e-7) -> None:
        """
        Args:
            ref_channel: Reference channel for beamforming. Default: 0
            multi_mask: If True, accepts multi-channel masks. Default: False
            diag_loading: If True, applies diagonal loading. Default: True
            diag_eps: Diagonal loading coefficient. Default: 1e-7
        """

    def forward(self, specgram: torch.Tensor, mask_s: torch.Tensor, mask_n: torch.Tensor) -> torch.Tensor:
        """
        Args:
            specgram: Multi-channel complex noisy spectrum (..., channel, freq, time)
            mask_s: Time-Frequency mask for target speech
            mask_n: Time-Frequency mask for noise

        Returns:
            Tensor: Enhanced single-channel complex spectrum (..., freq, time)
        """
```

### Advanced Processing Transforms

Specialized transforms for feature normalization and voice activity detection.

```python { .api }
class SlidingWindowCmn(torch.nn.Module):
    """Per-utterance cepstral mean normalization over a sliding window, optionally with variance normalization."""

    def __init__(
        self,
        cmn_window: int = 600,
        min_cmn_window: int = 100,
        center: bool = False,
        norm_vars: bool = False,
    ) -> None:
        """
        Args:
            cmn_window: Window size, in frames, of the running-average CMN. Default: 600
            min_cmn_window: Smallest window used at the start of decoding. Default: 100
            center: Use a window centered on the current frame (True) or one that
                extends to the left of it (False). Default: False
            norm_vars: Also scale the variance to one. Default: False
        """

    def forward(self, specgram: torch.Tensor) -> torch.Tensor:
        """
        Args:
            specgram: Spectrogram shaped (..., time, freq)

        Returns:
            Tensor: Normalized spectrogram, shaped like the input
        """

class Vad(torch.nn.Module):
    """Voice Activity Detector, similar to the SoX `vad` effect.

    Trims silence and quiet background sound from the *front* of a speech
    recording. To trim the end as well, reverse the waveform, apply Vad, and
    reverse it again.
    """

    def __init__(self, sample_rate: int, trigger_level: float = 7.0, trigger_time: float = 0.25,
                 search_time: float = 1.0, allowed_gap: float = 0.25, pre_trigger_time: float = 0.0,
                 boot_time: float = 0.35, noise_up_time: float = 0.1, noise_down_time: float = 0.01,
                 noise_reduction_amount: float = 1.35, measure_freq: float = 20.0,
                 measure_duration: Optional[float] = None, measure_smooth_time: float = 0.4,
                 hp_filter_freq: float = 50.0, lp_filter_freq: float = 6000.0,
                 hp_lifter_freq: float = 150.0, lp_lifter_freq: float = 2000.0) -> None:
        """
        Args:
            sample_rate: Sample rate of audio signal
            trigger_level: Measurement level used to trigger activity detection. Default: 7.0
            trigger_time: Time constant (seconds) to help ignore short bursts. Default: 0.25
            search_time: Amount of audio (seconds) to search for quieter bursts. Default: 1.0
            allowed_gap: Allowed gap (seconds) between quieter bursts. Default: 0.25
            pre_trigger_time: Amount of audio (seconds) to preserve before trigger. Default: 0.0
            boot_time: Time (seconds) for initial noise estimate. Default: 0.35
            noise_up_time: Time constant for increasing noise level. Default: 0.1
            noise_down_time: Time constant for decreasing noise level. Default: 0.01
            noise_reduction_amount: Amount of noise reduction. Default: 1.35
            measure_freq: Frequency (Hz) of algorithm processing. Default: 20.0
            measure_duration: Measurement duration. Default: None (twice the measurement period)
            measure_smooth_time: Time constant for spectral smoothing. Default: 0.4
            hp_filter_freq: High-pass filter frequency (Hz). Default: 50.0
            lp_filter_freq: Low-pass filter frequency (Hz). Default: 6000.0
            hp_lifter_freq: High-pass lifter frequency (Hz). Default: 150.0
            lp_lifter_freq: Low-pass lifter frequency (Hz). Default: 2000.0
        """

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Args:
            waveform: Input tensor (channels, time) or (time)

        Returns:
            Tensor: The input waveform with leading silence trimmed
            (not a per-sample activity mask)
        """
```

### Loss Functions

Loss functions for training neural networks on audio data.

```python { .api }
class RNNTLoss(torch.nn.Module):
    """Compute the RNN Transducer loss (Graves, "Sequence Transduction with Recurrent Neural Networks")."""

    def __init__(self, blank: int = -1, clamp: float = -1.0, reduction: str = "mean",
                 fused_log_softmax: bool = True) -> None:
        """
        Args:
            blank: Blank label index; negative values index from the end of the
                vocabulary (-1 is the last class). Default: -1
            clamp: Clamp for gradients; a negative value disables clamping. Default: -1
            reduction: Reduction to apply: "none", "mean", or "sum". Default: "mean"
            fused_log_softmax: Set to False if log_softmax is applied outside the loss.
                Default: True
        """

    def forward(self, logits: torch.Tensor, targets: torch.Tensor, logit_lengths: torch.Tensor,
                target_lengths: torch.Tensor) -> torch.Tensor:
        """
        Args:
            logits: Joiner output of shape (N, T, U + 1, V), where N = batch,
                T = max input length, U = max target length, V = vocabulary size
            targets: int32 tensor of shape (N, U) with zero-padded target sequences
            logit_lengths: int32 tensor of shape (N,) with valid input lengths
            target_lengths: int32 tensor of shape (N,) with valid target lengths

        Returns:
            Tensor: RNN Transducer loss; shape (N,) when reduction is "none",
            otherwise a scalar
        """
```

Usage example combining multiple transforms:

```python
import torch
import torchaudio
from torchaudio import transforms as T

# Load first so the resampler uses the file's actual sample rate
# instead of assuming the input is 44.1 kHz.
waveform, orig_sr = torchaudio.load("audio.wav")

target_sr = 16000

# Feature extraction + SpecAugment-style masking. The masking stages are
# data augmentation; drop them when computing features for inference.
transform_pipeline = torch.nn.Sequential(
    T.Resample(orig_freq=orig_sr, new_freq=target_sr),  # Resample to 16 kHz
    T.MelSpectrogram(
        sample_rate=target_sr,
        n_fft=1024,
        hop_length=256,
        n_mels=80,
    ),  # Convert to mel spectrogram
    T.AmplitudeToDB(stype="power"),  # Convert power to dB scale
    T.FrequencyMasking(freq_mask_param=15),  # Augmentation: frequency masking
    T.TimeMasking(time_mask_param=35),  # Augmentation: time masking
)

processed = transform_pipeline(waveform)  # (channels, n_mels, frames)
```

These transforms provide the building blocks for audio processing pipelines that integrate directly with PyTorch's neural network ecosystem.