import logging
import torch
from typing import Optional
from torch import Tensor
from torch.nn.functional import pad

from ..core.transforms_interface import BaseWaveformTransform
from ..utils.dsp import convert_decibels_to_amplitude_ratio
from ..utils.object_dict import ObjectDict


class SpliceOut(BaseWaveformTransform):
    """
    SpliceOut augmentation, proposed in https://arxiv.org/pdf/2110.00046.pdf

    For every example, ``num_time_intervals`` random intervals are spliced out.
    An interval of (even) length L is replaced by a crossfade of length L / 2:
    its first half is weighted by the decaying half of a Hann window, its second
    half by the rising half, and the two are summed. Silence (zero) padding is
    added at the end so the output keeps the input length.
    """

    supported_modes = {"per_batch", "per_example"}

    requires_sample_rate = True

    def __init__(
        self,
        num_time_intervals=8,
        max_width=400,
        mode: str = "per_example",
        p: float = 0.5,
        p_mode: Optional[str] = None,
        sample_rate: Optional[int] = None,
        target_rate: Optional[int] = None,
        output_type: Optional[str] = None,
    ):
        """
        :param num_time_intervals: number of time intervals to splice out per example
        :param max_width: maximum width of each spliced-out interval, in milliseconds
        :param mode, p, p_mode, sample_rate, target_rate, output_type: forwarded
            unchanged to ``BaseWaveformTransform.__init__``
        """
        super().__init__(
            mode=mode,
            p=p,
            p_mode=p_mode,
            sample_rate=sample_rate,
            target_rate=target_rate,
            output_type=output_type,
        )
        self.num_time_intervals = num_time_intervals
        self.max_width = max_width

    def randomize_parameters(
        self,
        samples: Tensor = None,
        sample_rate: Optional[int] = None,
        targets: Optional[Tensor] = None,
        target_rate: Optional[int] = None,
    ):
        """
        Draw one splice length (in samples) per example and per interval.

        Stores an integer tensor of shape ``(samples.shape[0], num_time_intervals)``
        in ``self.transform_parameters["splice_lengths"]``. Lengths are drawn
        uniformly from ``[0, max_width ms)``, converted to samples.
        """
        # Exclusive upper bound for randint, converted from milliseconds to
        # samples. Clamped to at least 1 so randint stays valid when max_width is
        # shorter than one sample (every length is then 0, i.e. a no-op).
        max_splice_length = max(1, int(sample_rate * self.max_width * 1e-3))
        self.transform_parameters["splice_lengths"] = torch.randint(
            low=0,
            high=max_splice_length,
            size=(samples.shape[0], self.num_time_intervals),
        )

    def apply_transform(
        self,
        samples: Tensor = None,
        sample_rate: Optional[int] = None,
        targets: Optional[Tensor] = None,
        target_rate: Optional[int] = None,
    ) -> ObjectDict:
        """
        Splice out the intervals drawn by ``randomize_parameters``.

        ``samples`` is indexed as ``samples[i][:, a:b]``, so it is assumed to be
        ``(batch, channels, time)`` — TODO confirm against the base class.
        The output has the same shape, dtype and device as ``samples``; targets
        and rates are passed through unchanged.
        """
        num_samples = samples.shape[-1]
        spliceout_samples = []
        for i in range(samples.shape[0]):
            # Copy to Python ints so the stored parameters are never mutated.
            splice_lengths = self.transform_parameters["splice_lengths"][i].tolist()
            sample = samples[i]
            for splice_length in splice_lengths:
                # The window is split into two equal halves, so the length must be
                # even. Round up first, then make sure the splice fits in what is
                # left of the (already shortened) sample.
                splice_length += splice_length % 2
                current_length = sample.shape[-1]
                splice_length = min(splice_length, current_length - current_length % 2)
                if splice_length == 0:
                    continue  # nothing to remove

                # Start is drawn after the even-rounding above, so
                # start + splice_length <= current_length always holds. The +1 lets a
                # splice end exactly at the last sample.
                start = int(
                    torch.randint(0, current_length - splice_length + 1, size=(1,))
                )
                half = splice_length // 2

                hann_window = torch.hann_window(splice_length, device=sample.device).to(
                    sample.dtype
                )
                fading_out = sample[:, start : start + half]
                fading_in = sample[:, start + half : start + splice_length]
                # Decaying window half on the first part, rising half on the second.
                crossfade = hann_window[half:] * fading_out + hann_window[:half] * fading_in

                sample = torch.cat(
                    (
                        sample[:, :start],
                        crossfade,
                        sample[:, start + splice_length :],
                    ),
                    dim=-1,
                )

            # Zero-pad at the end back to the original length, matching the input
            # dtype/device instead of forcing float32.
            padding = sample.new_zeros((sample.shape[0], num_samples - sample.shape[-1]))
            spliceout_samples.append(torch.cat((sample, padding), dim=-1))

        return ObjectDict(
            samples=torch.stack(spliceout_samples, dim=0)
            if spliceout_samples
            else samples.clone(),
            sample_rate=sample_rate,
            targets=targets,
            target_rate=target_rate,
        )
Memory