# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
import logging
import os
from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar

from lightning_utilities.core.imports import RequirementCache
from torch import nn
from typing_extensions import Concatenate, ParamSpec

import pytorch_lightning as pl

_log = logging.getLogger(__name__)


def is_overridden(method_name: str, instance: Optional[object] = None, parent: Optional[type[object]] = None) -> bool:
    """Report whether ``method_name`` is overridden on ``instance`` relative to ``parent``.

    The actual comparison is delegated to ``lightning_utilities.core.overrides.is_overridden``; this wrapper only
    handles the ``None`` instance case and infers ``parent`` when it is not given.

    Args:
        method_name: Name of the method to look up.
        instance: Object to inspect. ``None`` yields ``False``.
        parent: Base class to compare against. When omitted, it is inferred from ``instance`` as one of
            ``LightningModule``, ``LightningDataModule`` or ``Callback`` (checked in that order).

    Returns:
        ``False`` when ``instance`` is ``None``, otherwise the result of the delegated check.

    Raises:
        TypeError: From ``_check_mixed_imports`` when no parent could be inferred and the instance comes from the
            other Lightning import style.
        ValueError: When no parent was given and none could be inferred.

    """
    if instance is None:
        # if `self.lightning_module` was passed as instance, it can be `None`
        return False
    if parent is None:
        if isinstance(instance, pl.LightningModule):
            parent = pl.LightningModule
        elif isinstance(instance, pl.LightningDataModule):
            parent = pl.LightningDataModule
        elif isinstance(instance, pl.Callback):
            parent = pl.Callback
        if parent is None:
            # Give a more specific error first if the failure is caused by mixing import styles
            _check_mixed_imports(instance)
            raise ValueError("Expected a parent")

    # Imported locally, under an alias, because this function shadows the name
    from lightning_utilities.core.overrides import is_overridden as _is_overridden

    return _is_overridden(method_name, instance, parent)


def get_torchvision_model(model_name: str, **kwargs: Any) -> nn.Module:
    """Build a torchvision model by name.

    Args:
        model_name: Name of the model to build.
        **kwargs: Forwarded unchanged to the torchvision model builder.

    Returns:
        The result of ``torchvision.models.get_model(model_name, **kwargs)`` when torchvision>=0.14.0 is detected,
        otherwise the result of calling the attribute ``model_name`` of ``torchvision.models`` with ``kwargs``.

    Raises:
        ModuleNotFoundError: When ``_TORCHVISION_AVAILABLE`` is falsy; the message is its string form.

    """
    from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE

    if not _TORCHVISION_AVAILABLE:
        raise ModuleNotFoundError(str(_TORCHVISION_AVAILABLE))

    from torchvision import models

    torchvision_greater_equal_0_14 = RequirementCache("torchvision>=0.14.0")
    # TODO: deprecate this function when 0.14 is the minimum supported torchvision
    if torchvision_greater_equal_0_14:
        return models.get_model(model_name, **kwargs)
    return getattr(models, model_name)(**kwargs)


class _ModuleMode:
    """Captures the ``nn.Module.training`` (bool) mode of every submodule, and allows it to be restored later on."""

    def __init__(self) -> None:
        # Maps each name yielded by `named_modules()` to the `training` flag recorded for it
        self.mode: dict[str, bool] = {}

    def capture(self, module: nn.Module) -> None:
        """Record the ``training`` flag of every entry in ``module.named_modules()``, discarding earlier records."""
        self.mode.clear()
        for name, mod in module.named_modules():
            self.mode[name] = mod.training

    def restore(self, module: nn.Module) -> None:
        """Write the recorded ``training`` flags back onto the entries of ``module.named_modules()``.

        Entries whose name was not seen by the last ``capture`` call are left untouched and reported at debug level.

        """
        for name, mod in module.named_modules():
            if name not in self.mode:
                _log.debug(
                    f"Restoring training mode on module '{name}' not possible, it was never captured."
                    f" Is your module structure changing?"
                )
                continue
            mod.training = self.mode[name]


def _check_mixed_imports(instance: object) -> None:
    """Raise if ``instance`` comes from the other Lightning import style than this module.

    The module name of ``type(instance)`` and this module's ``__name__`` are compared by prefix. If one starts with
    ``pytorch_lightning`` and the other with ``lightning.pytorch``, a ``TypeError`` is raised; in any other
    combination the function returns silently.

    Raises:
        TypeError: When the instance's class and this module come from different import styles.

    """
    # NOTE(review): the names are assembled from two pieces, presumably so that automated package renames do not
    # rewrite these literals - confirm before simplifying.
    old, new = "pytorch_" + "lightning", "lightning." + "pytorch"
    klass = type(instance)
    module = klass.__module__
    if module.startswith(old) and __name__.startswith(new):
        pass
    elif module.startswith(new) and __name__.startswith(old):
        # Swap so that `old` always names the instance's package and `new` this module's package
        old, new = new, old
    else:
        return
    raise TypeError(
        f"You passed a `{old}` object ({type(instance).__qualname__}) to a `{new}`"
        " Trainer. Please switch to a single import style."
    )


_T = TypeVar("_T")  # type of the method owner
_P = ParamSpec("_P")  # parameters of the decorated method
_R_co = TypeVar("_R_co", covariant=True)  # return type of the decorated method


class _restricted_classmethod_impl(Generic[_T, _R_co, _P]):
    """Drop-in replacement for @classmethod, but raises an exception when the decorated method is called on an instance
    instead of a class type."""

    def __init__(self, method: Callable[Concatenate[type[_T], _P], _R_co]) -> None:
        self.method = method

    def __get__(self, instance: Optional[_T], cls: type[_T]) -> Callable[_P, _R_co]:
        """Return a callable that invokes the wrapped method with ``cls`` as its first argument.

        Args:
            instance: The object the attribute was accessed through, or ``None`` for access through the class.
            cls: The owner class; it is what gets passed to the wrapped method.

        Returns:
            A wrapper carrying the wrapped method's metadata. Calling it raises ``TypeError`` when the attribute was
            accessed through an instance, unless a frame from a ``torch/jit`` path is found on the call stack.

        """

        # The wrapper ensures that the method can be inspected, but not called on an instance
        @functools.wraps(self.method)
        def wrapper(*args: Any, **kwargs: Any) -> _R_co:
            # The stack is only inspected for instance access: walking it is costly and the outcome cannot matter
            # when the method is called on the class.
            if instance is not None:
                # Workaround for https://github.com/pytorch/pytorch/issues/67146
                jit_path = os.path.join("torch", "jit")
                # `context=0` skips loading source lines; only the frame filenames are needed
                is_scripting = any(jit_path in frameinfo.filename for frameinfo in inspect.stack(context=0))
                if not is_scripting:
                    raise TypeError(
                        f"The classmethod `{cls.__name__}.{self.method.__name__}` cannot be called on an instance."
                        " Please call it on the class type and make sure the return value is used."
                    )
            return self.method(cls, *args, **kwargs)

        return wrapper


if TYPE_CHECKING:
    # trick static type checkers into thinking it's a @classmethod
    # https://github.com/microsoft/pyright/issues/5865
    _restricted_classmethod = classmethod
else:
    _restricted_classmethod = _restricted_classmethod_impl
Memory