from typing import Optional
from torch import Tensor
from ..core.transforms_interface import BaseWaveformTransform
from ..utils.object_dict import ObjectDict
class Identity(BaseWaveformTransform):
"""
This transform returns the input unchanged. It can be used for simplifying the code
in cases where data augmentation should be disabled.
"""
supported_modes = {"per_batch", "per_example", "per_channel"}
supports_multichannel = True
requires_sample_rate = False
supports_target = True
requires_target = False
def apply_transform(
self,
samples: Tensor = None,
sample_rate: Optional[int] = None,
targets: Optional[Tensor] = None,
target_rate: Optional[int] = None,
) -> ObjectDict:
return ObjectDict(
samples=samples,
sample_rate=sample_rate,
targets=targets,
target_rate=target_rate,
)