from random import choices
from torch import Tensor
from typing import Optional
from torch_pitch_shift import pitch_shift, get_fast_shifts, semitones_to_ratio
from ..core.transforms_interface import BaseWaveformTransform
from ..utils.object_dict import ObjectDict
class PitchShift(BaseWaveformTransform):
    """
    Pitch-shift sounds up or down without changing the tempo.

    Only "fast" shift ratios (those that `torch_pitch_shift` can apply
    efficiently for the given sample rate) within the requested semitone
    range are used; they are precomputed once in the constructor.
    """

    supported_modes = {"per_batch", "per_example", "per_channel"}
    supports_multichannel = True
    requires_sample_rate = True
    supports_target = True
    requires_target = False

    def __init__(
        self,
        min_transpose_semitones: float = -4.0,
        max_transpose_semitones: float = 4.0,
        mode: str = "per_example",
        p: float = 0.5,
        p_mode: Optional[str] = None,
        sample_rate: Optional[int] = None,
        target_rate: Optional[int] = None,
        output_type: Optional[str] = None,
    ):
        """
        :param min_transpose_semitones: Minimum pitch shift transposition in semitones (default -4.0)
        :param max_transpose_semitones: Maximum pitch shift transposition in semitones (default +4.0)
        :param mode: ``per_example``, ``per_channel``, or ``per_batch``. Default ``per_example``.
        :param p: Probability of applying the transform.
        :param p_mode: Mode for applying the probability (see BaseWaveformTransform).
        :param sample_rate: Audio sample rate in Hz. Required, because the set of
            fast pitch-shift ratios depends on it.
        :param target_rate: Sample rate of the targets, if any.
        :param output_type: Output container type (see BaseWaveformTransform).
        :raises ValueError: If the transpose range is inverted, the sample rate is
            missing/zero, or no fast shift ratio exists in the requested range.
        """
        super().__init__(
            mode=mode,
            p=p,
            p_mode=p_mode,
            sample_rate=sample_rate,
            target_rate=target_rate,
            output_type=output_type,
        )
        if min_transpose_semitones > max_transpose_semitones:
            # Note: equality is allowed (a fixed transposition).
            raise ValueError("max_transpose_semitones must be >= min_transpose_semitones")
        if not sample_rate:
            raise ValueError("sample_rate is invalid.")
        self._sample_rate = sample_rate
        # Precompute every "fast" shift ratio inside [min, max] semitones,
        # excluding the identity ratio (1), since that would be a no-op.
        min_ratio = semitones_to_ratio(min_transpose_semitones)
        max_ratio = semitones_to_ratio(max_transpose_semitones)
        self._fast_shifts = get_fast_shifts(
            sample_rate,
            lambda ratio: min_ratio <= ratio <= max_ratio and ratio != 1,
        )
        if not self._fast_shifts:
            raise ValueError(
                "No fast pitch-shift ratios could be computed for the given sample rate and transpose range."
            )
        self._mode = mode

    def randomize_parameters(
        self,
        samples: Tensor = None,
        sample_rate: Optional[int] = None,
        targets: Optional[Tensor] = None,
        target_rate: Optional[int] = None,
    ):
        """
        Draw random shift ratios from the precomputed fast shifts and store them
        in ``self.transform_parameters["transpositions"]``.

        :param samples: (batch_size, num_channels, num_samples)
        :param sample_rate: Unused here; kept for the base-class interface.
        """
        batch_size, num_channels, num_samples = samples.shape

        if self._mode == "per_example":
            # One ratio per example, shared across its channels.
            self.transform_parameters["transpositions"] = choices(
                self._fast_shifts, k=batch_size
            )
        elif self._mode == "per_channel":
            # Draw batch_size ratios for each channel, then zip so that
            # transpositions[i][j] is the ratio for example i, channel j.
            self.transform_parameters["transpositions"] = list(
                zip(
                    *[
                        choices(self._fast_shifts, k=batch_size)
                        for i in range(num_channels)
                    ]
                )
            )
        elif self._mode == "per_batch":
            # A single ratio shared by the whole batch.
            self.transform_parameters["transpositions"] = choices(self._fast_shifts, k=1)

    def apply_transform(
        self,
        samples: Tensor = None,
        sample_rate: Optional[int] = None,
        targets: Optional[Tensor] = None,
        target_rate: Optional[int] = None,
    ) -> ObjectDict:
        """
        Apply the previously randomized pitch shifts to ``samples``.

        :param samples: (batch_size, num_channels, num_samples). Modified in
            place for ``per_example`` and ``per_channel`` modes.
        :param sample_rate: Must match the constructor's sample rate if given.
        :raises ValueError: If ``sample_rate`` differs from the constructor's.
        :return: ObjectDict with the shifted samples and pass-through targets.
        """
        batch_size, num_channels, num_samples = samples.shape

        # The fast shifts were computed for the constructor's sample rate, so a
        # different rate at call time cannot be honored.
        if sample_rate is not None and sample_rate != self._sample_rate:
            raise ValueError(
                "sample_rate must match the value of sample_rate "
                + "passed into the PitchShift constructor"
            )
        sample_rate = self.sample_rate

        if self._mode == "per_example":
            for i in range(batch_size):
                # pitch_shift expects a batch dim; add it with [None], strip with [0].
                samples[i, ...] = pitch_shift(
                    samples[i][None],
                    self.transform_parameters["transpositions"][i],
                    sample_rate,
                )[0]
        elif self._mode == "per_channel":
            for i in range(batch_size):
                for j in range(num_channels):
                    # Shape each mono channel to (1, 1, num_samples) for pitch_shift.
                    samples[i, j, ...] = pitch_shift(
                        samples[i][j][None][None],
                        self.transform_parameters["transpositions"][i][j],
                        sample_rate,
                    )[0][0]
        elif self._mode == "per_batch":
            samples = pitch_shift(
                samples, self.transform_parameters["transpositions"][0], sample_rate
            )

        return ObjectDict(
            samples=samples,
            sample_rate=sample_rate,
            targets=targets,
            target_rate=target_rate,
        )