from typing import Optional

import torch
from torch import Tensor

from ..core.transforms_interface import BaseWaveformTransform
from ..utils.dsp import calculate_rms
from ..utils.io import Audio
from ..utils.object_dict import ObjectDict


class Mix(BaseWaveformTransform):
    """
    Create a new sample by mixing it with another random sample from the same batch.

    The signal-to-noise ratio (where "noise" is the second random sample) is drawn
    uniformly between `min_snr_in_db` and `max_snr_in_db`. Setting both bounds to
    the same value yields a fixed SNR.

    The partner sample is drawn with replacement from the whole batch, so a sample
    may occasionally be mixed with (a rescaled copy of) itself.

    `mix_target` controls how resulting targets are generated. It can be one of
    "original" (targets are those of the original sample) or "union" (targets are
    the element-wise maximum of original and overlapping targets).
    """

    supported_modes = {"per_example", "per_channel"}

    supports_multichannel = True
    requires_sample_rate = False

    supports_target = True
    requires_target = False

    def __init__(
        self,
        min_snr_in_db: float = 0.0,
        max_snr_in_db: float = 5.0,
        mix_target: str = "union",
        mode: str = "per_example",
        p: float = 0.5,
        p_mode: Optional[str] = None,
        sample_rate: Optional[int] = None,
        target_rate: Optional[int] = None,
        output_type: Optional[str] = None,
    ):
        """
        :param min_snr_in_db: lower bound of the SNR range, in dB.
        :param max_snr_in_db: upper bound of the SNR range, in dB. It may equal
            `min_snr_in_db` for a fixed SNR.
        :param mix_target: "original" or "union" (see class docstring).
        :param mode: forwarded to BaseWaveformTransform.
        :param p: forwarded to BaseWaveformTransform.
        :param p_mode: forwarded to BaseWaveformTransform.
        :param sample_rate: forwarded to BaseWaveformTransform.
        :param target_rate: forwarded to BaseWaveformTransform.
        :param output_type: forwarded to BaseWaveformTransform.
        :raises ValueError: if min_snr_in_db > max_snr_in_db, or if mix_target
            is not one of the supported values.
        """
        super().__init__(
            mode=mode,
            p=p,
            p_mode=p_mode,
            sample_rate=sample_rate,
            target_rate=target_rate,
            output_type=output_type,
        )
        self.min_snr_in_db = min_snr_in_db
        self.max_snr_in_db = max_snr_in_db
        if self.min_snr_in_db > self.max_snr_in_db:
            raise ValueError("min_snr_in_db must not be greater than max_snr_in_db")

        if mix_target not in ("original", "union"):
            raise ValueError("mix_target must be one of 'original' or 'union'.")
        self.mix_target = mix_target

    def _mix_target(
        self, target: Tensor, background_target: Tensor, snr: Tensor
    ) -> Tensor:
        """Combine the targets of the original and background samples.

        This is a regular method rather than a lambda stored on the instance, so
        the transform stays picklable (e.g. for multi-process data loading).
        `snr` is accepted for interface compatibility and is currently unused.
        """
        if self.mix_target == "union":
            return torch.maximum(target, background_target)
        return target

    def randomize_parameters(
        self,
        samples: Optional[Tensor] = None,
        sample_rate: Optional[int] = None,
        targets: Optional[Tensor] = None,
        target_rate: Optional[int] = None,
    ):
        """Draw a per-example SNR and a partner-sample index for the batch.

        Stores two entries in `self.transform_parameters`:
        - "snr_in_db": float32 tensor of shape (batch_size,)
        - "sample_idx": integer tensor of shape (batch_size,)
        """
        # The unpacking also enforces that samples is rank 3.
        batch_size, _num_channels, _num_samples = samples.shape
        device = samples.device

        low = torch.tensor(self.min_snr_in_db, dtype=torch.float32, device=device)
        high = torch.tensor(self.max_snr_in_db, dtype=torch.float32, device=device)
        # This matches Uniform(low, high).sample(): low + U[0, 1) * (high - low).
        # Uniform itself is avoided because validate_args rejects low == high,
        # which __init__ explicitly allows.
        unit = torch.rand((batch_size,), dtype=torch.float32, device=device)
        self.transform_parameters["snr_in_db"] = low + unit * (high - low)

        # Partner index for each example, drawn with replacement (may be itself).
        self.transform_parameters["sample_idx"] = torch.randint(
            0,
            batch_size,
            (batch_size,),
            device=device,
        )

    def apply_transform(
        self,
        samples: Optional[Tensor] = None,
        sample_rate: Optional[int] = None,
        targets: Optional[Tensor] = None,
        target_rate: Optional[int] = None,
    ) -> ObjectDict:
        """Mix each sample with its randomly chosen partner at the drawn SNR.

        Returns an ObjectDict with the mixed samples and targets. The targets are
        None when no targets were given.
        """
        snr = self.transform_parameters["snr_in_db"]
        idx = self.transform_parameters["sample_idx"]

        # Presumably rms_normalize rescales each partner to unit RMS, so that
        # background_rms alone sets its level. TODO confirm in utils.io.
        background_samples = Audio.rms_normalize(samples[idx])
        # Background level = foreground RMS attenuated by snr dB (amplitude ratio
        # 10 ** (snr / 20)). This assumes calculate_rms reduces over the last
        # (time) axis. TODO confirm.
        background_rms = calculate_rms(samples) / (10 ** (snr.unsqueeze(dim=-1) / 20))

        mixed_samples = samples + background_rms.unsqueeze(-1) * background_samples

        if targets is None:
            mixed_targets = None
        else:
            background_targets = targets[idx]
            mixed_targets = self._mix_target(targets, background_targets, snr)

        return ObjectDict(
            samples=mixed_samples,
            sample_rate=sample_rate,
            targets=mixed_targets,
            target_rate=target_rate,
        )