# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Sequence, Tuple, Union

from torch import Tensor
from typing_extensions import Literal

from torchmetrics.classification.base import _ClassificationTaskWrapper
from torchmetrics.classification.precision_recall_curve import (
    BinaryPrecisionRecallCurve,
    MulticlassPrecisionRecallCurve,
    MultilabelPrecisionRecallCurve,
)
from torchmetrics.functional.classification.precision_fixed_recall import _precision_at_recall
from torchmetrics.functional.classification.recall_fixed_precision import (
    _binary_recall_at_fixed_precision_arg_validation,
    _binary_recall_at_fixed_precision_compute,
    _multiclass_recall_at_fixed_precision_arg_compute,
    _multiclass_recall_at_fixed_precision_arg_validation,
    _multilabel_recall_at_fixed_precision_arg_compute,
    _multilabel_recall_at_fixed_precision_arg_validation,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.enums import ClassificationTask
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = [
        "BinaryPrecisionAtFixedRecall.plot",
        "MulticlassPrecisionAtFixedRecall.plot",
        "MultilabelPrecisionAtFixedRecall.plot",
    ]


def _resolve_plot_value(
    metric: Metric, val: Optional[Union[Tensor, Sequence[Tensor]]]
) -> Union[Tensor, Sequence[Tensor]]:
    """Return the value to plot, falling back to the precision part of ``metric.compute()``.

    ``val or metric.compute()[0]`` must not be used for this: it invokes ``Tensor.__bool__``, which raises for
    tensors with more than one element (e.g. per-class / per-label results) and treats a genuine ``tensor(0.)``
    as "missing". An explicit ``None`` check avoids both. An empty list/tuple is still treated as missing, which
    preserves the previous fallback behaviour for that case.

    Args:
        metric: The metric whose ``compute`` output is used when ``val`` is missing.
        val: A single result, a sequence of results, or ``None``.

    Returns:
        ``val`` itself, or the first element (the precision) of ``metric.compute()``.

    """
    if val is None or (isinstance(val, (list, tuple)) and len(val) == 0):
        return metric.compute()[0]
    return val


class BinaryPrecisionAtFixedRecall(BinaryPrecisionRecallCurve):
    r"""Compute the highest possible precision value given the minimum recall thresholds provided.

    This is done by first calculating the precision-recall curve for different thresholds and then finding the
    precision for a given recall level.

    As input to ``forward`` and ``update`` the metric accepts the following input:

    - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, ...)``. Preds should be a tensor containing
      probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input
      to be logits and will auto apply sigmoid per element.
    - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``. Target should be a tensor containing
      ground truth labels, and therefore only contain {0,1} values (except if `ignore_index` is specified). The value
      1 always encodes the positive class.

    .. note::
       Additional dimension ``...`` will be flattened into the batch dimension.

    As output to ``forward`` and ``compute`` the metric returns the following output:

    - ``precision`` (:class:`~torch.Tensor`): A scalar tensor with the maximum precision for the given recall level
    - ``threshold`` (:class:`~torch.Tensor`): A scalar tensor with the corresponding threshold level

    .. note::
       The implementation both supports calculating the metric in a non-binned but accurate version and a
       binned version that is less accurate but more memory efficient. Setting the `thresholds` argument to
       ``None`` will activate the non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})`
       whereas setting the `thresholds` argument to either an integer, list or a 1d tensor will use a binned version
       that uses memory of size :math:`\mathcal{O}(n_{thresholds})` (constant memory).

    Args:
        min_recall: float value specifying minimum recall threshold.
        thresholds:
            Can be one of:

            - If set to ``None``, will use a non-binned approach where thresholds are dynamically calculated from
              all the data. Most accurate but also most memory consuming approach.
            - If set to an ``int`` (larger than 1), will use that number of thresholds linearly spaced from
              0 to 1 as bins for the calculation.
            - If set to a ``list`` of floats, will use the indicated thresholds in the list as bins for the calculation
            - If set to a 1d :class:`~torch.Tensor` of floats, will use the indicated thresholds in the tensor as
              bins for the calculation.

        ignore_index:
            Specifies a target value that is ignored and does not contribute to the metric calculation
        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
            Set to ``False`` for faster computations.
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Example:
        >>> from torch import tensor
        >>> from torchmetrics.classification import BinaryPrecisionAtFixedRecall
        >>> preds = tensor([0, 0.5, 0.7, 0.8])
        >>> target = tensor([0, 1, 1, 0])
        >>> metric = BinaryPrecisionAtFixedRecall(min_recall=0.5, thresholds=None)
        >>> metric(preds, target)
        (tensor(0.6667), tensor(0.5000))
        >>> metric = BinaryPrecisionAtFixedRecall(min_recall=0.5, thresholds=5)
        >>> metric(preds, target)
        (tensor(0.6667), tensor(0.5000))

    """

    is_differentiable: bool = False
    higher_is_better: Optional[bool] = None
    full_state_update: bool = False
    plot_lower_bound: float = 0.0
    plot_upper_bound: float = 1.0

    def __init__(
        self,
        min_recall: float,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> None:
        # The parent's own validation is disabled; this metric validates its full argument set (incl. min_recall).
        super().__init__(thresholds, ignore_index, validate_args=False, **kwargs)
        if validate_args:
            _binary_recall_at_fixed_precision_arg_validation(min_recall, thresholds, ignore_index)
        self.validate_args = validate_args
        self.min_recall = min_recall

    def compute(self) -> Tuple[Tensor, Tensor]:  # type: ignore[override]
        """Compute the maximum precision subject to the ``min_recall`` constraint, and its threshold."""
        # Non-binned mode (thresholds=None) concatenates the stored preds/targets; binned mode uses ``confmat``.
        state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
        # The recall-at-fixed-precision compute routine is reused with a precision-at-recall reduction.
        return _binary_recall_at_fixed_precision_compute(
            state, self.thresholds, self.min_recall, reduce_fn=_precision_at_recall
        )

    def plot(  # type: ignore[override]
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure object and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> from torch import rand, randint
            >>> # Example plotting a single value
            >>> from torchmetrics.classification import BinaryPrecisionAtFixedRecall
            >>> metric = BinaryPrecisionAtFixedRecall(min_recall=0.5)
            >>> metric.update(rand(10), randint(2,(10,)))
            >>> fig_, ax_ = metric.plot()  # the returned plot only shows the precision value by default

        .. plot::
            :scale: 75

            >>> from torch import rand, randint
            >>> # Example plotting multiple values
            >>> from torchmetrics.classification import BinaryPrecisionAtFixedRecall
            >>> metric = BinaryPrecisionAtFixedRecall(min_recall=0.5)
            >>> values = [ ]
            >>> for _ in range(10):
            ...     # we index by 0 such that only the precision value is plotted
            ...     values.append(metric(rand(10), randint(2,(10,)))[0])
            >>> fig_, ax_ = metric.plot(values)

        """
        # By default the precision (first element of the ``compute`` output) is plotted.
        return self._plot(_resolve_plot_value(self, val), ax)


class MulticlassPrecisionAtFixedRecall(MulticlassPrecisionRecallCurve):
    r"""Compute the highest possible precision value given the minimum recall thresholds provided.

    This is done by first calculating the precision-recall curve for different thresholds and then finding the
    precision for a given recall level.

    As input to ``forward`` and ``update`` the metric accepts the following input:

    - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)``. Preds should be a tensor
      containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider
      the input to be logits and will auto apply softmax per sample.
    - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``. Target should be a tensor containing
      ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if `ignore_index`
      is specified).

    .. note::
       Additional dimension ``...`` will be flattened into the batch dimension.

    As output to ``forward`` and ``compute`` the metric returns a tuple of either 2 tensors or 2 lists containing:

    - ``precision`` (:class:`~torch.Tensor`): A 1d tensor of size ``(n_classes, )`` with the maximum precision for the
      given recall level per class
    - ``threshold`` (:class:`~torch.Tensor`): A 1d tensor of size ``(n_classes, )`` with the corresponding threshold
      level per class

    .. note::
       The implementation both supports calculating the metric in a non-binned but accurate version and a binned
       version that is less accurate but more memory efficient. Setting the `thresholds` argument to ``None`` will
       activate the non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the
       `thresholds` argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
       size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory).

    Args:
        num_classes: Integer specifying the number of classes
        min_recall: float value specifying minimum recall threshold.
        thresholds:
            Can be one of:

            - If set to ``None``, will use a non-binned approach where thresholds are dynamically calculated from
              all the data. Most accurate but also most memory consuming approach.
            - If set to an ``int`` (larger than 1), will use that number of thresholds linearly spaced from
              0 to 1 as bins for the calculation.
            - If set to a ``list`` of floats, will use the indicated thresholds in the list as bins for the calculation
            - If set to a 1d :class:`~torch.Tensor` of floats, will use the indicated thresholds in the tensor as
              bins for the calculation.

        ignore_index:
            Specifies a target value that is ignored and does not contribute to the metric calculation
        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
            Set to ``False`` for faster computations.
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Example:
        >>> from torch import tensor
        >>> from torchmetrics.classification import MulticlassPrecisionAtFixedRecall
        >>> preds = tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
        ...                 [0.05, 0.75, 0.05, 0.05, 0.05],
        ...                 [0.05, 0.05, 0.75, 0.05, 0.05],
        ...                 [0.05, 0.05, 0.05, 0.75, 0.05]])
        >>> target = tensor([0, 1, 3, 2])
        >>> metric = MulticlassPrecisionAtFixedRecall(num_classes=5, min_recall=0.5, thresholds=None)
        >>> metric(preds, target)  # doctest: +NORMALIZE_WHITESPACE
        (tensor([1.0000, 1.0000, 0.2500, 0.2500, 0.0000]),
         tensor([7.5000e-01, 7.5000e-01, 5.0000e-02, 5.0000e-02, 1.0000e+06]))
        >>> mcrafp = MulticlassPrecisionAtFixedRecall(num_classes=5, min_recall=0.5, thresholds=5)
        >>> mcrafp(preds, target)  # doctest: +NORMALIZE_WHITESPACE
        (tensor([1.0000, 1.0000, 0.2500, 0.2500, 0.0000]),
         tensor([7.5000e-01, 7.5000e-01, 0.0000e+00, 0.0000e+00, 1.0000e+06]))

    """

    is_differentiable: bool = False
    higher_is_better: Optional[bool] = None
    full_state_update: bool = False
    plot_lower_bound: float = 0.0
    plot_upper_bound: float = 1.0
    plot_legend_name: str = "Class"

    def __init__(
        self,
        num_classes: int,
        min_recall: float,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> None:
        # The parent's own validation is disabled; this metric validates its full argument set (incl. min_recall).
        super().__init__(
            num_classes=num_classes, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs
        )
        if validate_args:
            _multiclass_recall_at_fixed_precision_arg_validation(num_classes, min_recall, thresholds, ignore_index)
        self.validate_args = validate_args
        self.min_recall = min_recall

    def compute(self) -> Tuple[Tensor, Tensor]:  # type: ignore[override]
        """Compute, per class, the maximum precision subject to the ``min_recall`` constraint, and its threshold."""
        # Non-binned mode (thresholds=None) concatenates the stored preds/targets; binned mode uses ``confmat``.
        state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
        # The recall-at-fixed-precision compute routine is reused with a precision-at-recall reduction.
        return _multiclass_recall_at_fixed_precision_arg_compute(
            state, self.num_classes, self.thresholds, self.min_recall, reduce_fn=_precision_at_recall
        )

    def plot(  # type: ignore[override]
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure object and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> from torch import rand, randint
            >>> # Example plotting a single value per class
            >>> from torchmetrics.classification import MulticlassPrecisionAtFixedRecall
            >>> metric = MulticlassPrecisionAtFixedRecall(num_classes=3, min_recall=0.5)
            >>> metric.update(rand(20, 3).softmax(dim=-1), randint(3, (20,)))
            >>> fig_, ax_ = metric.plot()  # the returned plot only shows the precision values by default

        .. plot::
            :scale: 75

            >>> from torch import rand, randint
            >>> # Example plotting multiple values per class
            >>> from torchmetrics.classification import MulticlassPrecisionAtFixedRecall
            >>> metric = MulticlassPrecisionAtFixedRecall(num_classes=3, min_recall=0.5)
            >>> values = []
            >>> for _ in range(20):
            ...     # we index by 0 such that only the precision values are plotted
            ...     values.append(metric(rand(20, 3).softmax(dim=-1), randint(3, (20,)))[0])
            >>> fig_, ax_ = metric.plot(values)

        """
        # The per-class result is a multi-element tensor, so ``val or ...`` would raise; see ``_resolve_plot_value``.
        return self._plot(_resolve_plot_value(self, val), ax)


class MultilabelPrecisionAtFixedRecall(MultilabelPrecisionRecallCurve):
    r"""Compute the highest possible precision value given the minimum recall thresholds provided.

    This is done by first calculating the precision-recall curve for different thresholds and then finding the
    precision for a given recall level.

    As input to ``forward`` and ``update`` the metric accepts the following input:

    - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)``. Preds should be a tensor
      containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider
      the input to be logits and will auto apply sigmoid per element.
    - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)``. Target should be a tensor
      containing ground truth labels, and therefore only contain {0,1} values (except if `ignore_index` is specified).
      The value 1 always encodes the positive class.

    .. note::
       Additional dimension ``...`` will be flattened into the batch dimension.

    As output to ``forward`` and ``compute`` the metric returns a tuple of either 2 tensors or 2 lists containing:

    - ``precision`` (:class:`~torch.Tensor`): A 1d tensor of size ``(n_classes, )`` with the maximum precision for the
      given recall level per class
    - ``threshold`` (:class:`~torch.Tensor`): A 1d tensor of size ``(n_classes, )`` with the corresponding threshold
      level per class

    .. note::
       The implementation both supports calculating the metric in a non-binned but accurate version and a binned
       version that is less accurate but more memory efficient. Setting the `thresholds` argument to ``None`` will
       activate the non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the
       `thresholds` argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
       size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory).

    Args:
        num_labels: Integer specifying the number of labels
        min_recall: float value specifying minimum recall threshold.
        thresholds:
            Can be one of:

            - If set to ``None``, will use a non-binned approach where thresholds are dynamically calculated from
              all the data. Most accurate but also most memory consuming approach.
            - If set to an ``int`` (larger than 1), will use that number of thresholds linearly spaced from
              0 to 1 as bins for the calculation.
            - If set to a ``list`` of floats, will use the indicated thresholds in the list as bins for the calculation
            - If set to a 1d :class:`~torch.Tensor` of floats, will use the indicated thresholds in the tensor as
              bins for the calculation.

        ignore_index:
            Specifies a target value that is ignored and does not contribute to the metric calculation
        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
            Set to ``False`` for faster computations.
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Example:
        >>> from torch import tensor
        >>> from torchmetrics.classification import MultilabelPrecisionAtFixedRecall
        >>> preds = tensor([[0.75, 0.05, 0.35],
        ...                 [0.45, 0.75, 0.05],
        ...                 [0.05, 0.55, 0.75],
        ...                 [0.05, 0.65, 0.05]])
        >>> target = tensor([[1, 0, 1],
        ...                  [0, 0, 0],
        ...                  [0, 1, 1],
        ...                  [1, 1, 1]])
        >>> metric = MultilabelPrecisionAtFixedRecall(num_labels=3, min_recall=0.5, thresholds=None)
        >>> metric(preds, target)
        (tensor([1.0000, 0.6667, 1.0000]), tensor([0.7500, 0.5500, 0.3500]))
        >>> mlrafp = MultilabelPrecisionAtFixedRecall(num_labels=3, min_recall=0.5, thresholds=5)
        >>> mlrafp(preds, target)
        (tensor([1.0000, 0.6667, 1.0000]), tensor([0.7500, 0.5000, 0.2500]))

    """

    is_differentiable: bool = False
    higher_is_better: Optional[bool] = None
    full_state_update: bool = False
    plot_lower_bound: float = 0.0
    plot_upper_bound: float = 1.0
    plot_legend_name: str = "Label"

    def __init__(
        self,
        num_labels: int,
        min_recall: float,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> None:
        # The parent's own validation is disabled; this metric validates its full argument set (incl. min_recall).
        super().__init__(
            num_labels=num_labels, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs
        )
        if validate_args:
            _multilabel_recall_at_fixed_precision_arg_validation(num_labels, min_recall, thresholds, ignore_index)
        self.validate_args = validate_args
        self.min_recall = min_recall

    def compute(self) -> Tuple[Tensor, Tensor]:  # type: ignore[override]
        """Compute, per label, the maximum precision subject to the ``min_recall`` constraint, and its threshold."""
        # Non-binned mode (thresholds=None) concatenates the stored preds/targets; binned mode uses ``confmat``.
        state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
        # Unlike the multiclass variant, the multilabel compute helper also receives ``ignore_index``.
        return _multilabel_recall_at_fixed_precision_arg_compute(
            state, self.num_labels, self.thresholds, self.ignore_index, self.min_recall, reduce_fn=_precision_at_recall
        )

    def plot(  # type: ignore[override]
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure object and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> from torch import rand, randint
            >>> # Example plotting a single value
            >>> from torchmetrics.classification import MultilabelPrecisionAtFixedRecall
            >>> metric = MultilabelPrecisionAtFixedRecall(num_labels=3, min_recall=0.5)
            >>> metric.update(rand(20, 3), randint(2, (20, 3)))
            >>> fig_, ax_ = metric.plot()  # the returned plot only shows the precision values by default

        .. plot::
            :scale: 75

            >>> from torch import rand, randint
            >>> # Example plotting multiple values
            >>> from torchmetrics.classification import MultilabelPrecisionAtFixedRecall
            >>> metric = MultilabelPrecisionAtFixedRecall(num_labels=3, min_recall=0.5)
            >>> values = [ ]
            >>> for _ in range(10):
            ...     # we index by 0 such that only the precision values are plotted
            ...     values.append(metric(rand(20, 3), randint(2, (20, 3)))[0])
            >>> fig_, ax_ = metric.plot(values)

        """
        # The per-label result is a multi-element tensor, so ``val or ...`` would raise; see ``_resolve_plot_value``.
        return self._plot(_resolve_plot_value(self, val), ax)


class PrecisionAtFixedRecall(_ClassificationTaskWrapper):
    r"""Compute the highest possible precision value given the minimum recall thresholds provided.

    This is done by first calculating the precision-recall curve for different thresholds and then finding the
    precision for a given recall level.

    This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
    ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
    :class:`BinaryPrecisionAtFixedRecall`, :class:`MulticlassPrecisionAtFixedRecall` and
    :class:`MultilabelPrecisionAtFixedRecall` for the specific details of each argument influence and examples.

    """

    def __new__(  # type: ignore[misc]
        cls,
        task: Literal["binary", "multiclass", "multilabel"],
        min_recall: float,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> Metric:
        """Initialize the task-specific metric selected by ``task``.

        Raises:
            ValueError:
                If ``task`` is ``"multiclass"`` and ``num_classes`` is not an ``int``, if ``task`` is
                ``"multilabel"`` and ``num_labels`` is not an ``int``, or if the task is not supported.

        """
        task = ClassificationTask.from_str(task)
        if task == ClassificationTask.BINARY:
            return BinaryPrecisionAtFixedRecall(min_recall, thresholds, ignore_index, validate_args, **kwargs)
        if task == ClassificationTask.MULTICLASS:
            if not isinstance(num_classes, int):
                raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
            return MulticlassPrecisionAtFixedRecall(
                num_classes, min_recall, thresholds, ignore_index, validate_args, **kwargs
            )
        if task == ClassificationTask.MULTILABEL:
            if not isinstance(num_labels, int):
                raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
            return MultilabelPrecisionAtFixedRecall(
                num_labels, min_recall, thresholds, ignore_index, validate_args, **kwargs
            )
        raise ValueError(f"Task {task} not supported!")
Memory