import typing
from typing import Optional

import torch
from torch import Tensor

from ..core.transforms_interface import BaseWaveformTransform
from ..utils.object_dict import ObjectDict


class PeakNormalization(BaseWaveformTransform):
    """
    Apply a constant amount of gain, so that highest signal level present in each audio snippet
    in the batch becomes 0 dBFS, i.e. the loudest level allowed if all samples must be between
    -1 and 1.

    This transform has an alternative mode (apply_to="only_too_loud_sounds") where it only
    applies to audio snippets that have extreme values outside the [-1, 1] range. This is useful
    for avoiding digital clipping in audio that is too loud, while leaving other audio
    untouched.

    Args:
        apply_to: Either "all" (normalize every snippet whose peak is non-zero) or
            "only_too_loud_sounds" (normalize only snippets whose peak exceeds 1.0).
        mode, p, p_mode, sample_rate, target_rate, output_type: Forwarded unchanged to
            ``BaseWaveformTransform.__init__``.
    """

    supported_modes = {"per_batch", "per_example", "per_channel"}

    supports_multichannel = True
    requires_sample_rate = False

    supports_target = True
    requires_target = False

    def __init__(
        self,
        apply_to: str = "all",
        mode: str = "per_example",
        p: float = 0.5,
        p_mode: typing.Optional[str] = None,
        sample_rate: typing.Optional[int] = None,
        target_rate: typing.Optional[int] = None,
        output_type: Optional[str] = None,
    ):
        super().__init__(
            mode=mode,
            p=p,
            p_mode=p_mode,
            sample_rate=sample_rate,
            target_rate=target_rate,
            output_type=output_type,
        )
        # Kept as an assert so the exception type seen by existing callers
        # (AssertionError) does not change.
        assert apply_to in ("all", "only_too_loud_sounds")
        self.apply_to = apply_to

    def randomize_parameters(
        self,
        samples: Tensor = None,
        sample_rate: Optional[int] = None,
        targets: Optional[Tensor] = None,
        target_rate: Optional[int] = None,
    ):
        """
        Decide which examples get normalized and by how much.

        Stores a boolean mask under ``transform_parameters["selector"]`` and, when at
        least one example is selected, the matching peak values under
        ``transform_parameters["divisors"]`` (reshaped to ``(-1, 1, 1)`` so they
        broadcast against the selected examples). Only ``samples`` is read; the other
        arguments are ignored here.

        Raises:
            ValueError: If ``self.apply_to`` holds an unsupported value.
        """
        # Largest absolute value per example: reduce the last dimension, then the new
        # last dimension. Assumes samples is 3-D, presumably (batch, channels, time).
        peak_per_channel, _ = torch.max(torch.abs(samples), dim=-1)
        peak_per_example, _ = torch.max(peak_per_channel, dim=-1)

        if self.apply_to == "all":
            # Skip all-zero examples to avoid division by zero.
            selector = peak_per_example > 0.0
        elif self.apply_to == "only_too_loud_sounds":
            # Apply peak normalization only to audio examples with
            # values outside the [-1, 1] range.
            selector = peak_per_example > 1.0
        else:
            raise ValueError("Unknown value of apply_to in PeakNormalization instance!")

        # Boolean mask with one entry per example.
        self.transform_parameters["selector"] = selector

        if selector.any():
            self.transform_parameters["divisors"] = torch.reshape(
                peak_per_example[selector], (-1, 1, 1)
            )
        else:
            # Drop divisors left over from an earlier call; otherwise apply_transform
            # would pair them with the new (empty) selection.
            self.transform_parameters.pop("divisors", None)

    def apply_transform(
        self,
        samples: Tensor = None,
        sample_rate: Optional[int] = None,
        targets: Optional[Tensor] = None,
        target_rate: Optional[int] = None,
    ) -> ObjectDict:
        """
        Divide the selected examples by their stored peak values.

        ``samples`` is modified in place. ``sample_rate``, ``targets`` and
        ``target_rate`` are passed through untouched.

        Returns:
            An ObjectDict holding ``samples``, ``sample_rate``, ``targets`` and
            ``target_rate``.
        """
        # "divisors" is absent when randomize_parameters selected nothing.
        if "divisors" in self.transform_parameters:
            selector = self.transform_parameters["selector"]
            samples[selector] /= self.transform_parameters["divisors"]

        return ObjectDict(
            samples=samples,
            sample_rate=sample_rate,
            targets=targets,
            target_rate=target_rate,
        )
Memory