import torch
from typing import Optional

from torch import Tensor

from ..core.transforms_interface import BaseWaveformTransform
from ..utils.object_dict import ObjectDict


class PeakNormalization(BaseWaveformTransform):
"""
Apply a constant amount of gain, so that highest signal level present in each audio snippet
in the batch becomes 0 dBFS, i.e. the loudest level allowed if all samples must be between
-1 and 1.
This transform has an alternative mode (apply_to="only_too_loud_sounds") where it only
applies to audio snippets that have extreme values outside the [-1, 1] range. This is useful
for avoiding digital clipping in audio that is too loud, while leaving other audio
untouched.
"""
supported_modes = {"per_batch", "per_example", "per_channel"}
supports_multichannel = True
requires_sample_rate = False
supports_target = True
requires_target = False

    def __init__(
        self,
        apply_to: str = "all",
        mode: str = "per_example",
        p: float = 0.5,
        p_mode: Optional[str] = None,
        sample_rate: Optional[int] = None,
        target_rate: Optional[int] = None,
        output_type: Optional[str] = None,
    ):
super().__init__(
mode=mode,
p=p,
p_mode=p_mode,
sample_rate=sample_rate,
target_rate=target_rate,
output_type=output_type,
)
        if apply_to not in ("all", "only_too_loud_sounds"):
            raise ValueError('apply_to must be "all" or "only_too_loud_sounds"')
        self.apply_to = apply_to

    def randomize_parameters(
        self,
        samples: Optional[Tensor] = None,
        sample_rate: Optional[int] = None,
        targets: Optional[Tensor] = None,
        target_rate: Optional[int] = None,
    ):
        # Compute the most extreme absolute value of each multichannel audio snippet
        # in the batch, reducing first over the time dimension and then over the
        # channel dimension
        most_extreme_values, _ = torch.max(torch.abs(samples), dim=-1)
        most_extreme_values, _ = torch.max(most_extreme_values, dim=-1)
if self.apply_to == "all":
# Avoid division by zero
self.transform_parameters["selector"] = (
most_extreme_values > 0.0
) # type: torch.BoolTensor
elif self.apply_to == "only_too_loud_sounds":
# Apply peak normalization only to audio examples with
# values outside the [-1, 1] range
self.transform_parameters["selector"] = (
most_extreme_values > 1.0
) # type: torch.BoolTensor
        else:
            raise ValueError("Unknown value of apply_to in PeakNormalization instance!")
        if self.transform_parameters["selector"].any():
            # Store the peak value of each selected snippet as its divisor, reshaped
            # to (num_selected, 1, 1) so it broadcasts over the channel and time
            # dimensions during the division in apply_transform
            self.transform_parameters["divisors"] = torch.reshape(
                most_extreme_values[self.transform_parameters["selector"]], (-1, 1, 1)
            )

    def apply_transform(
        self,
        samples: Optional[Tensor] = None,
        sample_rate: Optional[int] = None,
        targets: Optional[Tensor] = None,
        target_rate: Optional[int] = None,
    ) -> ObjectDict:
if "divisors" in self.transform_parameters:
samples[self.transform_parameters["selector"]] /= self.transform_parameters[
"divisors"
]
return ObjectDict(
samples=samples,
sample_rate=sample_rate,
targets=targets,
target_rate=target_rate,
)
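

# A minimal usage sketch, not part of the library code. It assumes the class is
# exposed as torch_audiomentations.PeakNormalization (the public import path used
# by the published package) and that inputs follow the (batch, channels, time)
# shape convention implied by the reductions above.
if __name__ == "__main__":
    import torch

    from torch_audiomentations import PeakNormalization

    # Two mono snippets: a quiet one and one that clips outside [-1, 1]
    samples = torch.stack(
        [
            0.1 * torch.randn(1, 16000),
            1.5 * torch.randn(1, 16000),
        ]
    )

    # With apply_to="only_too_loud_sounds" and p=1.0, only the clipping snippet
    # is rescaled; the quiet snippet passes through unchanged.
    # output_type="dict" (supported in recent versions) returns an ObjectDict.
    transform = PeakNormalization(
        apply_to="only_too_loud_sounds", p=1.0, output_type="dict"
    )
    output = transform(samples, sample_rate=16000)

    # The too-loud snippet now peaks at 1.0; the quiet one keeps its original peak
    print(output.samples.abs().amax(dim=(-2, -1)))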