import random
from pathlib import Path
from typing import Union, List, Optional

import torch
from torch import Tensor

from ..core.transforms_interface import BaseWaveformTransform, EmptyPathException
from ..utils.dsp import calculate_rms
from ..utils.file import find_audio_files_in_paths
from ..utils.io import Audio
from ..utils.object_dict import ObjectDict


class AddBackgroundNoise(BaseWaveformTransform):
    """
    Add background noise to the input audio.
    """

    supported_modes = {"per_batch", "per_example", "per_channel"}

    # Note: This transform has only partial support for multichannel audio. Noises that are
    # not mono get mixed down to mono before they are added to all channels in the input.
    supports_multichannel = True
    requires_sample_rate = True

    supports_target = True
    requires_target = False

    def __init__(
        self,
        background_paths: Union[List[Path], List[str], Path, str],
        min_snr_in_db: float = 3.0,
        max_snr_in_db: float = 30.0,
        mode: str = "per_example",
        p: float = 0.5,
        p_mode: Optional[str] = None,
        sample_rate: Optional[int] = None,
        target_rate: Optional[int] = None,
        output_type: Optional[str] = None,
    ):
        """
        :param background_paths: Either a path to a folder with audio files or a list of paths
            to audio files.
        :param min_snr_in_db: minimum SNR in dB.
        :param max_snr_in_db: maximum SNR in dB.
        :param mode: "per_batch", "per_example" or "per_channel" (see supported_modes)
        :param p: the probability of applying this transform
        :param p_mode: the granularity at which p is applied ("per_batch", "per_example"
            or "per_channel")
        :param sample_rate: sample rate of the input audio; may instead be provided when
            the transform is called
        """
        super().__init__(
            mode=mode,
            p=p,
            p_mode=p_mode,
            sample_rate=sample_rate,
            target_rate=target_rate,
            output_type=output_type,
        )

        # TODO: check that one can read audio files
        self.background_paths = find_audio_files_in_paths(background_paths)

        if sample_rate is not None:
            self.audio = Audio(sample_rate=sample_rate, mono=True)

        if len(self.background_paths) == 0:
            raise EmptyPathException("No supported audio files found in the given paths.")

        self.min_snr_in_db = min_snr_in_db
        self.max_snr_in_db = max_snr_in_db
        if self.min_snr_in_db > self.max_snr_in_db:
            raise ValueError("min_snr_in_db must not be greater than max_snr_in_db")

    def random_background(self, audio: Audio, target_num_samples: int) -> torch.Tensor:
        pieces = []

        # TODO: support repeating short samples instead of concatenating from different files

        # Draw random background files (and a random offset within the last one) until
        # the concatenated pieces cover target_num_samples.
        missing_num_samples = target_num_samples
        while missing_num_samples > 0:
            background_path = random.choice(self.background_paths)
            background_num_samples = audio.get_num_samples(background_path)

            if background_num_samples > missing_num_samples:
                sample_offset = random.randint(
                    0, background_num_samples - missing_num_samples
                )
                num_samples = missing_num_samples
                background_samples = audio(
                    background_path, sample_offset=sample_offset, num_samples=num_samples
                )
                missing_num_samples = 0
            else:
                background_samples = audio(background_path)
                missing_num_samples -= background_num_samples

            pieces.append(background_samples)

        # The inner call to rms_normalize ensures concatenated pieces share the same RMS (1).
        # The outer call to rms_normalize ensures that the resulting background has an RMS of 1
        # (this simplifies the apply_transform logic).
        return audio.rms_normalize(
            torch.cat([audio.rms_normalize(piece) for piece in pieces], dim=1)
        )

    def randomize_parameters(
        self,
        samples: Tensor = None,
        sample_rate: Optional[int] = None,
        targets: Optional[Tensor] = None,
        target_rate: Optional[int] = None,
    ):
        """
        :param samples: (batch_size, num_channels, num_samples)
        """
        batch_size, _, num_samples = samples.shape

        # (batch_size, num_samples) RMS-normalized background noise
        audio = self.audio if hasattr(self, "audio") else Audio(sample_rate, mono=True)
        self.transform_parameters["background"] = torch.stack(
            [self.random_background(audio, num_samples) for _ in range(batch_size)]
        )

        # (batch_size, ) SNRs
        if self.min_snr_in_db == self.max_snr_in_db:
            self.transform_parameters["snr_in_db"] = torch.full(
                size=(batch_size,),
                fill_value=self.min_snr_in_db,
                dtype=torch.float32,
                device=samples.device,
            )
        else:
            snr_distribution = torch.distributions.Uniform(
                low=torch.tensor(
                    self.min_snr_in_db, dtype=torch.float32, device=samples.device
                ),
                high=torch.tensor(
                    self.max_snr_in_db, dtype=torch.float32, device=samples.device
                ),
                validate_args=True,
            )
            self.transform_parameters["snr_in_db"] = snr_distribution.sample(
                sample_shape=(batch_size,)
            )

    def apply_transform(
        self,
        samples: Tensor = None,
        sample_rate: Optional[int] = None,
        targets: Optional[Tensor] = None,
        target_rate: Optional[int] = None,
    ) -> ObjectDict:
        batch_size, num_channels, num_samples = samples.shape

        # (batch_size, num_samples)
        background = self.transform_parameters["background"].to(samples.device)

        # The stored background has an RMS of 1, so scaling it by background_rms gives it
        # exactly the RMS that realizes the desired SNR:
        #   snr_in_db = 20 * log10(rms(samples) / rms(background))
        #   => rms(background) = rms(samples) / 10 ** (snr_in_db / 20)
        # (batch_size, num_channels)
        background_rms = calculate_rms(samples) / (
            10 ** (self.transform_parameters["snr_in_db"].unsqueeze(dim=-1) / 20)
        )

        return ObjectDict(
            samples=samples
            + background_rms.unsqueeze(-1)
            * background.view(batch_size, 1, num_samples).expand(-1, num_channels, -1),
            sample_rate=sample_rate,
            targets=targets,
            target_rate=target_rate,
        )
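

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the transform above): how this class
# might be applied to a batch of waveforms. The package-level import, the
# "background_noises/" folder and the tensor shapes are assumptions made for
# the example; adjust them to your setup. The guard keeps this from running on
# import; execute it as a module (python -m ...) so the relative imports resolve.
if __name__ == "__main__":
    from torch_audiomentations import AddBackgroundNoise  # assumed public export

    transform = AddBackgroundNoise(
        background_paths="background_noises/",  # hypothetical folder of noise files
        min_snr_in_db=3.0,
        max_snr_in_db=30.0,
        p=1.0,  # always apply, so the effect is easy to inspect
        sample_rate=16000,
        output_type="dict",
    )
    samples = torch.rand(8, 1, 16000)  # (batch_size, num_channels, num_samples)
    noisy = transform(samples).samples  # ObjectDict output when output_type="dict"
    print(noisy.shape)  # torch.Size([8, 1, 16000])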