from random import choices
from typing import Optional

from torch import Tensor
from torch_pitch_shift import get_fast_shifts, pitch_shift, semitones_to_ratio

from ..core.transforms_interface import BaseWaveformTransform
from ..utils.object_dict import ObjectDict


class PitchShift(BaseWaveformTransform):
    """
    Pitch-shift sounds up or down without changing the tempo.

    torch_pitch_shift only supports a discrete set of "fast" transposition
    ratios for a given sample rate, so the requested semitone range is mapped
    to that set once, at construction time.
    """

    supported_modes = {"per_batch", "per_example", "per_channel"}
    supports_multichannel = True
    requires_sample_rate = True
    supports_target = True
    requires_target = False

    def __init__(
        self,
        min_transpose_semitones: float = -4.0,
        max_transpose_semitones: float = 4.0,
        mode: str = "per_example",
        p: float = 0.5,
        p_mode: Optional[str] = None,
        sample_rate: Optional[int] = None,
        target_rate: Optional[int] = None,
        output_type: Optional[str] = None,
    ):
        """
        :param min_transpose_semitones: Minimum pitch shift transposition in semitones (default -4.0)
        :param max_transpose_semitones: Maximum pitch shift transposition in semitones (default +4.0)
        :param mode: ``per_example``, ``per_channel``, or ``per_batch``. Default ``per_example``.
        :param p:
        :param p_mode:
        :param sample_rate: Audio sample rate. Required here (unlike some other
            transforms), because the set of fast pitch-shift ratios depends on it.
        :param target_rate:
        :param output_type:
        :raises ValueError: if the transpose range is inverted, sample_rate is
            missing/zero, or no fast shift ratios exist for the given settings.
        """
        super().__init__(
            mode=mode,
            p=p,
            p_mode=p_mode,
            sample_rate=sample_rate,
            target_rate=target_rate,
            output_type=output_type,
        )
        if min_transpose_semitones > max_transpose_semitones:
            raise ValueError("max_transpose_semitones must be > min_transpose_semitones")
        if not sample_rate:
            raise ValueError("sample_rate is invalid.")
        self._sample_rate = sample_rate
        # Keep only ratios that torch_pitch_shift can compute quickly, that lie
        # within the requested semitone range, and that are not the identity (1).
        self._fast_shifts = get_fast_shifts(
            sample_rate,
            lambda x: x >= semitones_to_ratio(min_transpose_semitones)
            and x <= semitones_to_ratio(max_transpose_semitones)
            and x != 1,
        )
        if not len(self._fast_shifts):
            raise ValueError(
                "No fast pitch-shift ratios could be computed for the given sample rate and transpose range."
            )
        self._mode = mode

    def randomize_parameters(
        self,
        samples: Tensor = None,
        sample_rate: Optional[int] = None,
        targets: Optional[Tensor] = None,
        target_rate: Optional[int] = None,
    ):
        """
        Pick random transposition ratio(s) from the precomputed fast shifts,
        stored in ``self.transform_parameters["transpositions"]``.

        :param samples: (batch_size, num_channels, num_samples)
        :param sample_rate:
        """
        batch_size, num_channels, num_samples = samples.shape
        if self._mode == "per_example":
            # One ratio per example in the batch.
            self.transform_parameters["transpositions"] = choices(
                self._fast_shifts, k=batch_size
            )
        elif self._mode == "per_channel":
            # One ratio per (example, channel): build num_channels columns of
            # batch_size draws, then zip into per-example tuples of ratios.
            self.transform_parameters["transpositions"] = list(
                zip(
                    *[
                        choices(self._fast_shifts, k=batch_size)
                        for _ in range(num_channels)
                    ]
                )
            )
        elif self._mode == "per_batch":
            # A single ratio shared by the whole batch.
            self.transform_parameters["transpositions"] = choices(self._fast_shifts, k=1)

    def apply_transform(
        self,
        samples: Tensor = None,
        sample_rate: Optional[int] = None,
        targets: Optional[Tensor] = None,
        target_rate: Optional[int] = None,
    ) -> ObjectDict:
        """
        Apply the previously randomized pitch shift(s) to ``samples``.

        :param samples: (batch_size, num_channels, num_samples)
        :param sample_rate: must equal the sample_rate given to the constructor
        :raises ValueError: if ``sample_rate`` differs from the constructor's.
        """
        batch_size, num_channels, num_samples = samples.shape

        # The fast shift ratios were computed for a specific sample rate;
        # applying them at a different rate would be incorrect.
        if sample_rate is not None and sample_rate != self._sample_rate:
            raise ValueError(
                "sample_rate must match the value of sample_rate "
                + "passed into the PitchShift constructor"
            )
        sample_rate = self.sample_rate

        # NOTE: per_example/per_channel write shifted audio back into `samples`
        # in place; per_batch rebinds `samples` to a new tensor.
        if self._mode == "per_example":
            for i in range(batch_size):
                samples[i, ...] = pitch_shift(
                    samples[i][None],
                    self.transform_parameters["transpositions"][i],
                    sample_rate,
                )[0]
        elif self._mode == "per_channel":
            for i in range(batch_size):
                for j in range(num_channels):
                    samples[i, j, ...] = pitch_shift(
                        samples[i][j][None][None],
                        self.transform_parameters["transpositions"][i][j],
                        sample_rate,
                    )[0][0]
        elif self._mode == "per_batch":
            samples = pitch_shift(
                samples, self.transform_parameters["transpositions"][0], sample_rate
            )

        return ObjectDict(
            samples=samples,
            sample_rate=sample_rate,
            targets=targets,
            target_rate=target_rate,
        )