# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from collections.abc import Iterator
from typing import Any, Optional, Union

import torch
from lightning_utilities import WarningCache

import pytorch_lightning as pl
from lightning_fabric.utilities import move_data_to_device
from pytorch_lightning.callbacks import BasePredictionWriter
from pytorch_lightning.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher
from pytorch_lightning.loops.loop import _Loop
from pytorch_lightning.loops.progress import _Progress
from pytorch_lightning.loops.utilities import _no_grad_context, _select_data_fetcher, _verify_dataloader_idx_requirement
from pytorch_lightning.overrides.distributed import _IndexBatchSamplerWrapper
from pytorch_lightning.strategies.launchers import _MultiProcessingLauncher
from pytorch_lightning.trainer import call
from pytorch_lightning.trainer.connectors.data_connector import (
    _check_dataloader_iterable,
    _DataLoaderSource,
    _parse_num_batches,
    _process_dataloader,
    _request_dataloader,
)
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.utilities.combined_loader import CombinedLoader
from pytorch_lightning.utilities.data import has_len_all_ranks
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import _ModuleMode
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import _PREDICT_OUTPUT


class _PredictionLoop(_Loop):
    """Top-level loop where prediction starts.

    The loop requests the ``predict_dataloader`` and wraps it in a ``"sequential"`` :class:`CombinedLoader` unless it
    already is one. It iterates over the batches of every dataloader, running ``predict_step`` together with the
    ``on_predict_batch_start`` / ``on_predict_batch_end`` hooks. Predictions are moved to CPU and cached when they
    must be returned from the loop or when a :class:`~pytorch_lightning.callbacks.BasePredictionWriter` writes on
    epoch end.

    Args:
        trainer: the trainer that owns this loop.
        inference_mode: stored on the loop; presumably consulted by ``_no_grad_context`` to choose the no-grad
            flavour used by :meth:`run` (TODO confirm).

    """

    def __init__(self, trainer: "pl.Trainer", inference_mode: bool = True) -> None:
        super().__init__(trainer)
        self.inference_mode = inference_mode
        # dataloaders x batches x samples. used by PredictionWriter
        self.epoch_batch_indices: list[list[list[int]]] = []
        self.current_batch_indices: list[int] = []  # used by PredictionWriter
        self.batch_progress = _Progress()  # across dataloaders
        # per-dataloader number of batches to run, filled in by `setup_data`
        self.max_batches: list[Union[int, float]] = []

        self._warning_cache = WarningCache()
        self._data_source = _DataLoaderSource(None, "predict_dataloader")
        self._combined_loader: Optional[CombinedLoader] = None
        self._data_fetcher: Optional[_DataFetcher] = None
        self._results = None  # for `trainer._results` access
        self._predictions: list[list[Any]] = []  # dataloaders x batches
        self._return_predictions = False
        self._module_mode = _ModuleMode()

    @property
    def return_predictions(self) -> bool:
        """Whether to return the predictions or not."""
        return self._return_predictions

    @return_predictions.setter
    def return_predictions(self, return_predictions: Optional[bool] = None) -> None:
        """Sets whether predictions are returned.

        ``None`` selects the default: ``True`` unless the strategy's launcher is a ``_MultiProcessingLauncher``.

        Raises:
            MisconfigurationException: if ``True`` is requested with a spawning/forking launcher.

        """
        # Strategies that spawn or fork don't support returning predictions
        return_supported = not isinstance(self.trainer.strategy.launcher, _MultiProcessingLauncher)
        if return_predictions and not return_supported:
            raise MisconfigurationException(
                "`return_predictions` should be set to `False` when using the strategies that spawn or fork."
                f" Found {return_predictions} with strategy {type(self.trainer.strategy)}."
            )
        # For strategies that support it, `return_predictions` is True by default unless the user decides otherwise.
        self._return_predictions = return_supported if return_predictions is None else return_predictions

    @property
    def predictions(self) -> list[Any]:
        """The cached predictions.

        Returns an empty list when nothing is cached. With a single dataloader, this is the list of per-batch
        predictions of that dataloader; otherwise it is one such list per dataloader.

        """
        if not self._predictions:
            return self._predictions
        return self._predictions[0] if self.num_dataloaders == 1 else self._predictions

    @property
    def num_dataloaders(self) -> int:
        """Returns the number of prediction dataloaders."""
        combined_loader = self._combined_loader
        assert combined_loader is not None
        return len(combined_loader.flattened)

    @property
    def skip(self) -> bool:
        """Whether there is nothing to run, i.e. every dataloader is limited to zero batches."""
        return sum(self.max_batches) == 0

    @_no_grad_context
    def run(self) -> Optional[_PREDICT_OUTPUT]:
        """Runs prediction over every batch of every dataloader.

        Returns:
            the collected predictions, or ``None`` when the run is skipped or predictions are not returned.

        """
        self.setup_data()
        if self.skip:
            return None
        self.reset()
        self.on_run_start()
        data_fetcher = self._data_fetcher
        assert data_fetcher is not None
        while True:
            try:
                if isinstance(data_fetcher, _DataLoaderIterDataFetcher):
                    dataloader_iter = next(data_fetcher)
                    # hook's batch_idx and dataloader_idx arguments correctness cannot be guaranteed in this setting
                    batch = data_fetcher._batch
                    batch_idx = data_fetcher._batch_idx
                    dataloader_idx = data_fetcher._dataloader_idx
                else:
                    dataloader_iter = None
                    batch, batch_idx, dataloader_idx = next(data_fetcher)
                self.batch_progress.is_last_batch = data_fetcher.done
                # run step hooks
                self._predict_step(batch, batch_idx, dataloader_idx, dataloader_iter)
            except StopIteration:
                # this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support
                break
            finally:
                self._restarting = False
        return self.on_run_end()

    def setup_data(self) -> None:
        """Requests and processes the predict dataloaders and computes the per-dataloader batch limits."""
        trainer = self.trainer
        # a default `predict_step` exists in the LightningModule, so no need to check if it's overridden
        if trainer.limit_predict_batches == 0:
            # drop limits left over from a previous run so that `skip` reports this run as empty
            self.max_batches = []
            return

        source = self._data_source
        dataloaders = _request_dataloader(source)
        trainer.strategy.barrier("predict_dataloader()")

        if not isinstance(dataloaders, CombinedLoader):
            combined_loader = CombinedLoader(dataloaders, "sequential")
        else:
            combined_loader = dataloaders

        allow_zero_length = trainer.lightning_module.allow_zero_length_dataloader_with_multiple_devices
        if trainer.datamodule is not None:
            allow_zero_length |= trainer.datamodule.allow_zero_length_dataloader_with_multiple_devices

        trainer_fn = TrainerFn.PREDICTING
        stage = RunningStage.PREDICTING
        dataloaders = []
        self.max_batches = []
        for dl in combined_loader.flattened:
            _check_dataloader_iterable(dl, source, trainer_fn)
            dl = _process_dataloader(trainer, trainer_fn, stage, dl)
            dataloaders.append(dl)

            # determine number of batches
            length = len(dl) if has_len_all_ranks(dl, trainer.strategy, allow_zero_length) else float("inf")
            num_batches = _parse_num_batches(stage, length, trainer.limit_predict_batches)
            self.max_batches.append(num_batches)
        combined_loader.flattened = dataloaders
        self._combined_loader = combined_loader

    def reset(self) -> None:
        """Resets the internal state of the loop for a new run.

        Raises:
            ValueError: if the combined loader is not in ``"sequential"`` mode.

        """
        self.batch_progress.reset_on_run()

        assert self.trainer.state.stage is not None
        data_fetcher = _select_data_fetcher(self.trainer, self.trainer.state.stage)
        combined_loader = self._combined_loader
        assert combined_loader is not None

        if combined_loader._mode != "sequential":
            raise ValueError('`trainer.predict()` only supports the `CombinedLoader(mode="sequential")` mode.')

        # set the per-dataloader limits
        combined_loader.limits = self.max_batches
        data_fetcher.setup(combined_loader)
        iter(data_fetcher)  # creates the iterator inside the fetcher

        # add the previous `fetched` value to properly track `is_last_batch` with no prefetching
        data_fetcher.fetched += self.batch_progress.current.ready
        data_fetcher._start_profiler = self._on_before_fetch
        data_fetcher._stop_profiler = self._on_after_fetch
        self._data_fetcher = data_fetcher

        num_dataloaders = self.num_dataloaders
        self.epoch_batch_indices = [[] for _ in range(num_dataloaders)]
        self._predictions = [[] for _ in range(num_dataloaders)]

    def on_run_start(self) -> None:
        """Verifies the ``dataloader_idx`` requirements, then calls ``_on_predict_model_eval``,
        ``_on_predict_start`` and ``_on_predict_epoch_start`` hooks."""
        self._verify_dataloader_idx_requirement()
        self._on_predict_model_eval()
        self._on_predict_start()
        self._on_predict_epoch_start()

    def on_run_end(self) -> Optional[_PREDICT_OUTPUT]:
        """Calls ``on_predict_epoch_end`` and ``on_predict_end`` hooks, restores the module mode captured at the
        start, and returns results from all dataloaders."""
        results = self._on_predict_epoch_end()
        self._on_predict_end()
        self._on_predict_model_train()
        return results

    def teardown(self) -> None:
        """Tears down the data fetcher, if any, and drops the reference to it."""
        if self._data_fetcher is not None:
            self._data_fetcher.teardown()
            self._data_fetcher = None

    def _predict_step(
        self, batch: Any, batch_idx: int, dataloader_idx: int, dataloader_iter: Optional[Iterator]
    ) -> None:
        """Runs the actual predict step together with all the necessary bookkeeping and the hooks tied to it.

        Args:
            batch: the current batch to run the prediction on
            batch_idx: The index of the current batch.
            dataloader_idx: the index of the dataloader producing the current batch.
            dataloader_iter: The iterator if using this step flavor.

        """
        trainer = self.trainer
        data_fetcher = self._data_fetcher
        assert data_fetcher is not None

        # with `dataloader_iter`, `predict_step` fetches (and transfers) the batch itself
        if not (using_dataloader_iter := isinstance(data_fetcher, _DataLoaderIterDataFetcher)):
            batch = trainer.precision_plugin.convert_input(batch)
            batch = trainer.lightning_module._on_before_batch_transfer(batch, dataloader_idx=dataloader_idx)
            batch = call._call_strategy_hook(trainer, "batch_to_device", batch, dataloader_idx=dataloader_idx)

        self.batch_progress.increment_ready()

        any_on_epoch = (
            self._store_data_for_prediction_writer(batch_idx, dataloader_idx) if not using_dataloader_iter else False
        )

        # the `_step` methods don't take a batch_idx when `dataloader_iter` is used, but all other hooks still do,
        # so we need different kwargs
        hook_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx if self.num_dataloaders > 1 else None)

        call._call_callback_hooks(trainer, "on_predict_batch_start", *hook_kwargs.values())
        call._call_lightning_module_hook(trainer, "on_predict_batch_start", *hook_kwargs.values())

        self.batch_progress.increment_started()

        # configure step_kwargs
        step_args = (
            self._build_step_args_from_hook_kwargs(hook_kwargs, "predict_step")
            if not using_dataloader_iter
            else (dataloader_iter,)
        )
        predictions = call._call_strategy_hook(trainer, "predict_step", *step_args)
        if predictions is None:
            self._warning_cache.warn("predict returned None if it was on purpose, ignore this warning...")

        self.batch_progress.increment_processed()

        if using_dataloader_iter:
            # update the hook kwargs now that the step method might have consumed the iterator
            batch = data_fetcher._batch
            batch_idx = data_fetcher._batch_idx
            dataloader_idx = data_fetcher._dataloader_idx
            hook_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx if self.num_dataloaders > 1 else None)

        call._call_callback_hooks(trainer, "on_predict_batch_end", predictions, *hook_kwargs.values())
        call._call_lightning_module_hook(trainer, "on_predict_batch_end", predictions, *hook_kwargs.values())

        self.batch_progress.increment_completed()

        # keep predictions either for the caller or for epoch-level prediction writers
        if self._return_predictions or any_on_epoch:
            self._predictions[dataloader_idx].append(move_data_to_device(predictions, torch.device("cpu")))

    def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]) -> OrderedDict:
        """Assembles the arguments shared by ``predict_step`` and the ``on_predict_batch_*`` hooks.

        Args:
            batch: the current batch to run the prediction on
            batch_idx: the index of the current batch.
            dataloader_idx: the index of the dataloader producing the current batch. None if not multiple dataloaders
                in sequential mode.

        Returns:
            an ordered dictionary of keyword arguments. Its values are passed positionally, in the order ``batch``,
            ``batch_idx`` and, only when given, ``dataloader_idx``.

        """
        step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)])
        if dataloader_idx is not None:
            step_kwargs["dataloader_idx"] = dataloader_idx
        return step_kwargs

    def _build_step_args_from_hook_kwargs(self, hook_kwargs: OrderedDict, step_hook_name: str) -> tuple:
        """Helper method to build args for `predict_step`.

        ``batch_idx`` is dropped when ``is_param_in_hook_signature`` reports that the step method does not accept it.

        """
        kwargs = hook_kwargs.copy()
        step_hook_fx = getattr(self.trainer.lightning_module, step_hook_name)
        if not is_param_in_hook_signature(step_hook_fx, "batch_idx", min_args=2):
            kwargs.pop("batch_idx", None)
        return tuple(kwargs.values())

    def _get_batch_indices(self, dataloader: object) -> list[list[int]]:  # batches x samples
        """Returns a reference to the seen batch indices if the dataloader has a batch sampler wrapped by our
        :class:`~pytorch_lightning.overrides.distributed._IndexBatchSamplerWrapper`.

        Otherwise, warns once and returns an empty list.

        """
        batch_sampler = getattr(dataloader, "batch_sampler", None)
        if not isinstance(batch_sampler, _IndexBatchSamplerWrapper):
            self._warning_cache.warn(
                f"Couldn't infer the batch indices fetched from your dataloader: `{type(dataloader).__name__}`"
            )
            return []
        return batch_sampler.seen_batch_indices

    def _store_data_for_prediction_writer(self, batch_idx: int, dataloader_idx: int) -> bool:
        """Records the sample indices of the current batch for the registered prediction writers.

        Returns:
            whether any :class:`BasePredictionWriter` writes on epoch end, in which case the caller keeps the
            predictions of this batch.

        """
        prediction_writers = [cb for cb in self.trainer.callbacks if isinstance(cb, BasePredictionWriter)]
        any_on_epoch = any(cb.interval.on_epoch for cb in prediction_writers)
        any_on_batch = any(cb.interval.on_batch for cb in prediction_writers)
        if any_on_batch or any_on_epoch:
            combined_loader = self._combined_loader
            assert combined_loader is not None
            dataloader = combined_loader.flattened[dataloader_idx]
            batch_indices = self._get_batch_indices(dataloader)
            if not batch_indices:
                # this is only available with `_IndexBatchSamplerWrapper`, but it's only used on DataLoaders, if this is
                # reached, it's likely because a non-DataLoader was passed
                return any_on_epoch
            batch_indices = batch_indices[batch_idx]
            if any_on_epoch:
                self.epoch_batch_indices[dataloader_idx].append(batch_indices)
            if any_on_batch:
                self.current_batch_indices = batch_indices
        return any_on_epoch

    def _on_before_fetch(self) -> None:
        """Starts the profiler action that measures fetching the next batch."""
        self.trainer.profiler.start(f"[{type(self).__name__}].predict_next")

    def _on_after_fetch(self) -> None:
        """Stops the profiler action started by ``_on_before_fetch``."""
        # the dataloader_idx cannot be easily included here because it might be different from the index used on
        # profiler start, since the `__next__` call might use a different iterator
        self.trainer.profiler.stop(f"[{type(self).__name__}].predict_next")

    def _on_predict_start(self) -> None:
        """Calls ``on_predict_start`` hooks."""
        trainer = self.trainer
        call._call_callback_hooks(trainer, "on_predict_start")
        call._call_lightning_module_hook(trainer, "on_predict_start")
        call._call_strategy_hook(trainer, "on_predict_start")

    def _on_predict_model_eval(self) -> None:
        """Snapshots the LightningModule's mode via ``_ModuleMode`` and calls the ``on_predict_model_eval`` hook."""
        self._module_mode.capture(self.trainer.lightning_module)
        call._call_lightning_module_hook(self.trainer, "on_predict_model_eval")

    def _on_predict_model_train(self) -> None:
        """Restores the LightningModule's mode snapshotted by ``_on_predict_model_eval``."""
        self._module_mode.restore(self.trainer.lightning_module)

    def _on_predict_epoch_start(self) -> None:
        """Calls ``on_predict_epoch_start`` hooks."""
        trainer = self.trainer
        call._call_callback_hooks(trainer, "on_predict_epoch_start")
        call._call_lightning_module_hook(trainer, "on_predict_epoch_start")

    def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
        """Calls ``on_predict_epoch_end`` hook.

        Returns:
            the results for all dataloaders, or ``None`` when predictions are not returned

        """
        trainer = self.trainer
        call._call_callback_hooks(trainer, "on_predict_epoch_end")
        call._call_lightning_module_hook(trainer, "on_predict_epoch_end")

        if self.return_predictions:
            return self.predictions
        return None

    def _on_predict_end(self) -> None:
        """Releases cached state and calls ``on_predict_end`` hooks.

        Cached predictions are dropped unless they are being returned. The epoch batch indices are always cleared.

        """
        if not self.return_predictions:
            self._predictions = []
        self.epoch_batch_indices = []

        trainer = self.trainer
        # hook
        call._call_callback_hooks(trainer, "on_predict_end")
        call._call_lightning_module_hook(trainer, "on_predict_end")
        call._call_strategy_hook(trainer, "on_predict_end")

    def _verify_dataloader_idx_requirement(self) -> None:
        """Checks that the step and batch hooks accept ``dataloader_idx`` when multiple sequential dataloaders are
        used (the ``predict_step`` check is skipped in ``dataloader_iter`` mode)."""
        trainer = self.trainer
        assert self._combined_loader is not None
        # the names below resolve to the module-level helper imported from `loops.utilities`, not to this method
        _verify_dataloader_idx_requirement(
            ("predict_step",),
            self._combined_loader._mode == "sequential"
            and self.num_dataloaders > 1
            and not isinstance(self._data_fetcher, _DataLoaderIterDataFetcher),
            RunningStage.PREDICTING,
            trainer.lightning_module,
        )
        _verify_dataloader_idx_requirement(
            ("on_predict_batch_start", "on_predict_batch_end"),
            self._combined_loader._mode == "sequential" and self.num_dataloaders > 1,
            RunningStage.PREDICTING,
            trainer.lightning_module,
        )
Memory