# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from dataclasses import dataclass, fields
from typing import Any, Iterable, List, Optional, Tuple

from .image_processing_utils import BaseImageProcessor
from .utils.import_utils import is_torch_available, is_torchvision_available


if is_torchvision_available():
    from torchvision.transforms import Compose

if is_torch_available():
    import torch


@dataclass(frozen=True)
class SizeDict:
    """
    Hashable dictionary to store image size information.

    The dataclass is frozen, so instances are hashable and can be passed as keyword arguments to
    cached functions. Every field defaults to `None`; fields can be read either as attributes or with
    dictionary-style indexing (`size["height"]`).
    """

    height: Optional[int] = None
    width: Optional[int] = None
    longest_edge: Optional[int] = None
    shortest_edge: Optional[int] = None
    max_height: Optional[int] = None
    max_width: Optional[int] = None

    def __getitem__(self, key):
        """
        Return the value of the size field named `key`.

        Raises:
            KeyError: If `key` is not one of the declared size fields.
        """
        # Only declared fields are valid keys. A bare `hasattr` check would also accept methods and
        # dunder attributes (e.g. "__class__") and would raise `TypeError` on non-string keys.
        if any(field.name == key for field in fields(self)):
            return getattr(self, key)
        raise KeyError(f"Key {key} not found in SizeDict.")


class BaseImageProcessorFast(BaseImageProcessor):
    """
    Base class for image processors that build their transform pipeline through `_build_transforms` and
    reuse it through the cached `get_transforms`.
    """

    # Collection of keyword names accepted by `get_transforms`. `_validate_params` tests membership in
    # it, so subclasses are expected to override the `None` default.
    _transform_params = None

    def _build_transforms(self, **kwargs) -> "Compose":
        """
        Given the input settings e.g. do_resize, build the image transforms.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def _validate_params(self, **kwargs) -> None:
        """
        Check that every keyword argument is listed in `_transform_params`.

        Raises:
            ValueError: If a keyword is not a known transform parameter.
        """
        for k, v in kwargs.items():
            if k not in self._transform_params:
                raise ValueError(f"Invalid transform parameter {k}={v}.")

    # NOTE: `lru_cache` builds its key from `self` and all keyword values, so every value passed here
    # must be hashable (hence `SizeDict`). The decorator is applied at class level, which means the
    # single cache slot is shared by all instances and keeps the last-used instance referenced.
    @functools.lru_cache(maxsize=1)
    def get_transforms(self, **kwargs) -> "Compose":
        """
        Validate the transform parameters and return the transforms built from them.

        The most recent result is cached, so repeated calls with identical arguments do not rebuild
        the transforms.
        """
        self._validate_params(**kwargs)
        return self._build_transforms(**kwargs)

    def to_dict(self):
        """
        Return the parent class's dictionary representation without the `_transform_params` entry.
        """
        encoder_dict = super().to_dict()
        encoder_dict.pop("_transform_params", None)
        return encoder_dict


def get_image_size_for_max_height_width(
    image_size: Tuple[int, int],
    max_height: int,
    max_width: int,
) -> Tuple[int, int]:
    """
    Computes the output image size given the input image size and the maximum allowed height and width.
    Keeps the aspect ratio.
    Important, even if image_height < max_height and image_width < max_width, the image will be resized
    so that at least one of the edges is equal to max_height or max_width.

    For example:
        - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50)
        - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400)

    Args:
        image_size (`Tuple[int, int]`):
            The (height, width) of the image to resize.
        max_height (`int`):
            The maximum allowed height.
        max_width (`int`):
            The maximum allowed width.

    Returns:
        `Tuple[int, int]`: The (height, width) of the resized image.
    """
    height, width = image_size
    height_scale = max_height / height
    width_scale = max_width / width

    # The edge with the smaller scale factor is the limiting one and lands exactly on its maximum.
    # The other edge is derived with a single floor division rather than `int(edge * float_scale)`:
    # the float product can fall just below the exact integer (e.g. 49 * (1 / 49) < 1) and would then
    # be truncated one pixel short.
    if height_scale <= width_scale:
        new_height = int(max_height)
        new_width = int(width * max_height // height)
    else:
        new_height = int(height * max_width // width)
        new_width = int(max_width)
    return new_height, new_width


def safe_squeeze(tensor: "torch.Tensor", axis: Optional[int] = None) -> "torch.Tensor":
    """
    Squeezes a tensor, but only if the axis specified has dim 1.

    When `axis` is `None`, the plain `squeeze()` result is returned. Otherwise, if squeezing along
    `axis` raises a `ValueError`, the tensor is returned unchanged.
    """
    if axis is None:
        return tensor.squeeze()

    try:
        return tensor.squeeze(axis=axis)
    except ValueError:
        return tensor


def max_across_indices(values: Iterable[Any]) -> List[Any]:
    """
    Return the maximum value across all indices of an iterable of values.

    For example, `[(1, 5), (3, 2)]` gives `[3, 5]`. As with `zip`, positions beyond the shortest item
    are ignored.
    """
    return [max(values_i) for values_i in zip(*values)]


def get_max_height_width(images: List["torch.Tensor"]) -> Tuple[int, int]:
    """
    Get the maximum height and width across all images in a batch.

    Each image shape must have exactly three entries; the last two are taken as height and width.
    """
    _, max_height, max_width = max_across_indices([img.shape for img in images])

    return (max_height, max_width)
Memory