from typing import Optional

import torch
from torch import Tensor

from ..core.transforms_interface import BaseWaveformTransform
from ..utils.object_dict import ObjectDict


class Padding(BaseWaveformTransform):
    """Zero out a randomly sized section at the start or end of each example.

    Despite the name, the signal length is not changed. The chosen section of
    the existing samples is overwritten with zeros in place.

    For each entry along the first dimension of ``samples``, the number of
    zeroed samples is drawn uniformly from
    ``[int(L * min_fraction), int(L * max_fraction)]`` (both ends inclusive).
    Here ``L`` is the size of the last dimension of ``samples``.
    """

    supported_modes = {"per_batch", "per_example", "per_channel"}

    supports_multichannel = True
    requires_sample_rate = False

    supports_target = True
    requires_target = False

    def __init__(
        self,
        min_fraction=0.1,
        max_fraction=0.5,
        pad_section="end",
        mode="per_batch",
        p=0.5,
        p_mode: Optional[str] = None,
        sample_rate: Optional[int] = None,
        target_rate: Optional[int] = None,
        output_type: Optional[str] = None,
    ):
        """
        :param min_fraction: Minimum fraction of the signal length to zero out.
            Must be >= 0.
        :param max_fraction: Maximum fraction of the signal length to zero out.
            Must be >= min_fraction.
        :param pad_section: ``"start"`` or ``"end"``, the side of the signal
            that gets zeroed.
        :param mode: Forwarded to ``BaseWaveformTransform``.
        :param p: Forwarded to ``BaseWaveformTransform``.
        :param p_mode: Forwarded to ``BaseWaveformTransform``.
        :param sample_rate: Forwarded to ``BaseWaveformTransform``.
        :param target_rate: Forwarded to ``BaseWaveformTransform``.
        :param output_type: Forwarded to ``BaseWaveformTransform``.
        :raises ValueError: If ``min_fraction`` is negative (or NaN), or if it
            exceeds ``max_fraction``.
        :raises AssertionError: If ``pad_section`` is not ``"start"`` or ``"end"``.
        """
        super().__init__(
            mode=mode,
            p=p,
            p_mode=p_mode,
            sample_rate=sample_rate,
            target_rate=target_rate,
            output_type=output_type,
        )
        self.min_fraction = min_fraction
        self.max_fraction = max_fraction
        self.pad_section = pad_section

        # Written as `not x >= 0.0` so that NaN is rejected as well.
        if not self.min_fraction >= 0.0:
            raise ValueError("minimum fraction should be greater than or equal to zero.")
        if self.min_fraction > self.max_fraction:
            raise ValueError(
                "minimum fraction should be less than or equal to maximum fraction."
            )
        # NOTE(review): `assert` is stripped under `python -O`. It is kept here
        # so the exception type raised to existing callers is unchanged.
        assert self.pad_section in (
            "start",
            "end",
        ), 'pad_section must be "start" or "end"'

    def randomize_parameters(
        self,
        samples: Tensor = None,
        sample_rate: Optional[int] = None,
        targets: Optional[Tensor] = None,
        target_rate: Optional[int] = None,
    ):
        """Draw one pad length (in samples) per entry of ``samples``' first dim.

        The result is stored in ``self.transform_parameters["pad_length"]``
        as a 1-D integer tensor of length ``samples.shape[0]``.
        """
        input_length = samples.shape[-1]
        min_length = int(input_length * self.min_fraction)
        max_length = int(input_length * self.max_fraction)
        # The upper bound of torch.randint is exclusive. Adding 1 makes
        # max_fraction reachable. It also avoids an empty range (which raises)
        # when min_fraction == max_fraction, a case __init__ allows.
        self.transform_parameters["pad_length"] = torch.randint(
            min_length,
            max_length + 1,
            (samples.shape[0],),
        )

    def apply_transform(
        self,
        samples: Tensor,
        sample_rate: Optional[int] = None,
        targets: Optional[Tensor] = None,
        target_rate: Optional[int] = None,
    ) -> ObjectDict:
        """Zero the selected section of each example in place and return it.

        ``samples`` is indexed as ``samples[example, channel, time]``, so a
        3-D tensor is expected. The tensor is modified in place. ``targets``
        and the rates are passed through unchanged.
        """
        input_length = samples.shape[-1]
        for i, pad_length in enumerate(
            self.transform_parameters["pad_length"].tolist()
        ):
            if pad_length <= 0:
                # Nothing to zero. Without this guard, an end slice of `-0:`
                # would select (and zero) the whole example.
                continue
            if self.pad_section == "start":
                samples[i, :, :pad_length] = 0.0
            else:
                # Use an explicit, clamped start index instead of `-pad_length:`.
                # This stays correct when pad_length exceeds the signal length
                # (max_fraction > 1).
                samples[i, :, max(input_length - pad_length, 0) :] = 0.0

        return ObjectDict(
            samples=samples,
            sample_rate=sample_rate,
            targets=targets,
            target_rate=target_rate,
        )
Memory