import sys
from typing import TYPE_CHECKING, Any, Literal, Optional
import pytorch_lightning as pl
from lightning_fabric.utilities.rank_zero import rank_zero_deprecation
from pytorch_lightning.plugins.precision import (
BitsandbytesPrecision,
DeepSpeedPrecision,
DoublePrecision,
FSDPPrecision,
HalfPrecision,
MixedPrecision,
Precision,
TransformerEnginePrecision,
XLAPrecision,
)
if TYPE_CHECKING:
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
def _patch_sys_modules() -> None:
sys.modules["pytorch_lightning.plugins.precision.precision_plugin"] = sys.modules[
"pytorch_lightning.plugins.precision.precision"
]
class FSDPMixedPrecisionPlugin(FSDPPrecision):
"""AMP for Fully Sharded Data Parallel (FSDP) Training.
.. deprecated:: Use :class:`FSDPPrecision` instead.
.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.
"""
def __init__(
self, precision: Literal["16-mixed", "bf16-mixed"], device: str, scaler: Optional["ShardedGradScaler"] = None
) -> None:
rank_zero_deprecation(
f"The `{type(self).__name__}` is deprecated."
" Use `pytorch_lightning.plugins.precision.FSDPPrecision` instead."
)
super().__init__(precision=precision, scaler=scaler)
def _create_class(deprecated_name: str, new_class: type) -> type:
def init(self: type, *args: Any, **kwargs: Any) -> None:
rank_zero_deprecation(
f"The `{deprecated_name}` is deprecated."
f" Use `pytorch_lightning.plugins.precision.{new_class.__name__}` instead."
)
new_class.__init__(self, *args, **kwargs) # type: ignore[misc]
return type(deprecated_name, (new_class,), {"__init__": init})
def _patch_classes() -> None:
classes_map = (
# module name, old name, new class
("bitsandbytes", "BitsandbytesPrecisionPlugin", BitsandbytesPrecision),
("deepspeed", "DeepSpeedPrecisionPlugin", DeepSpeedPrecision),
("double", "DoublePrecisionPlugin", DoublePrecision),
("fsdp", "FSDPPrecisionPlugin", FSDPPrecision),
("fsdp", "FSDPMixedPrecisionPlugin", FSDPPrecision),
("half", "HalfPrecisionPlugin", HalfPrecision),
("amp", "MixedPrecisionPlugin", MixedPrecision),
("precision", "PrecisionPlugin", Precision),
("transformer_engine", "TransformerEnginePrecisionPlugin", TransformerEnginePrecision),
("xla", "XLAPrecisionPlugin", XLAPrecision),
)
for module_name, deprecated_name, new_class in classes_map:
deprecated_class = _create_class(deprecated_name, new_class)
setattr(getattr(pl.plugins.precision, module_name), deprecated_name, deprecated_class)
setattr(pl.plugins.precision, deprecated_name, deprecated_class)
setattr(pl.plugins, deprecated_name, deprecated_class)
# special treatment for `FSDPMixedPrecisionPlugin` because it has a different signature
setattr(pl.plugins.precision.fsdp, "FSDPMixedPrecisionPlugin", FSDPMixedPrecisionPlugin)
setattr(pl.plugins.precision, "FSDPMixedPrecisionPlugin", FSDPMixedPrecisionPlugin)
setattr(pl.plugins, "FSDPMixedPrecisionPlugin", FSDPMixedPrecisionPlugin)
_patch_sys_modules()
_patch_classes()