import torch
import torch.nn.functional as F
from ..distances import CosineSimilarity
from ..utils import common_functions as c_f
from ..utils import loss_and_miner_utils as lmu
from .base_metric_loss_function import BaseMetricLossFunction


class SmoothAPLoss(BaseMetricLossFunction):
"""
Implementation of the SmoothAP loss: https://arxiv.org/abs/2007.12163
"""
def __init__(self, temperature=0.01, **kwargs):
super().__init__(**kwargs)
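        # SmoothAP is defined on cosine similarity scores, so only CosineSimilarity is supported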
c_f.assert_distance_type(self, CosineSimilarity)
        self.temperature = temperature

    def get_default_distance(self):
        return CosineSimilarity()

    # Implementation is based on the original repository:
# https://github.com/Andrew-Brown1/Smooth_AP/blob/master/src/Smooth_AP_loss.py#L87
def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
        # The loss expects the batch to contain the same number of samples for each class.
        # The number of classes and their order do not matter, but every class must appear
        # the same number of times, and samples that share a label must be grouped
        # contiguously (this is the layout produced by samplers.MPerClassSampler), e.g.
        #
        # The following labels are valid:
        #   [ A,A,A, B,B,B, C,C,C ]
        # The following labels are NOT valid (class A has 4 elements, the others have 3):
        #   [ B,B,B, A,A,A,A, C,C,C ]
        #
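        # For example (illustrative values, 3 classes with m=3 samples each):
        #   sampler = samplers.MPerClassSampler(train_labels, m=3, batch_size=9)
        # produces batches whose labels satisfy both requirements.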
c_f.labels_required(labels)
c_f.ref_not_supported(embeddings, labels, ref_emb, ref_labels)
counts = torch.bincount(labels)
nonzero_indices = torch.nonzero(counts, as_tuple=True)[0]
nonzero_counts = counts[nonzero_indices]
if nonzero_counts.unique().size(0) != 1:
raise ValueError(
"All classes must have the same number of elements in the labels.\n"
"The given labels have the following number of elements: {}.\n"
"You can achieve this using the samplers.MPerClassSampler class and setting the batch_size and m.".format(
nonzero_counts.cpu().tolist()
)
)
        batch_size = embeddings.size(0)
        # Number of distinct classes in the batch; the check above guarantees that each
        # class contributes exactly batch_size // num_classes_batch samples.
        num_classes_batch = torch.unique(labels).size(0)
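        # Smoothed ranking over the whole batch: sims_ranks[q, j] approximates the rank of
        # sample j when all batch samples are sorted by similarity to query q. The sigmoid
        # relaxes the 0/1 indicator "k is ranked above j", and the mask drops the j == k terms.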
mask = 1.0 - torch.eye(batch_size)
mask = mask.unsqueeze(dim=0).repeat(batch_size, 1, 1)
sims = self.distance(embeddings)
sims_repeat = sims.unsqueeze(dim=1).repeat(1, batch_size, 1)
sims_diff = sims_repeat - sims_repeat.permute(0, 2, 1)
sims_sigm = F.sigmoid(sims_diff / self.temperature) * mask.to(sims_diff.device)
sims_ranks = torch.sum(sims_sigm, dim=-1) + 1
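        # Smoothed ranking restricted to the positives: the embeddings are reshaped into
        # (num_classes, samples_per_class, dim) blocks of same-class samples, and the same
        # sigmoid relaxation is applied within each block, so sims_pos_ranks[c, q, j]
        # approximates the rank of positive j among class c's samples for query q.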
xs = embeddings.view(
num_classes_batch, batch_size // num_classes_batch, embeddings.size(-1)
)
pos_mask = 1.0 - torch.eye(batch_size // num_classes_batch)
pos_mask = (
pos_mask.unsqueeze(dim=0)
.unsqueeze(dim=0)
.repeat(num_classes_batch, batch_size // num_classes_batch, 1, 1)
)
        # Use maybe_normalize / compute_mat directly to circumvent the shape check in the
        # distance's forward method (xs is a 3D tensor here)
xs_norm = self.distance.maybe_normalize(xs, dim=-1)
sims_pos = self.distance.compute_mat(xs_norm, xs_norm)
sims_pos_repeat = sims_pos.unsqueeze(dim=2).repeat(
1, 1, batch_size // num_classes_batch, 1
)
sims_pos_diff = sims_pos_repeat - sims_pos_repeat.permute(0, 1, 3, 2)
sims_pos_sigm = F.sigmoid(sims_pos_diff / self.temperature) * pos_mask.to(
sims_diff.device
)
sims_pos_ranks = torch.sum(sims_pos_sigm, dim=-1) + 1
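        # Smooth-AP per element: with g positives per class (each query counts as its own
        # positive), AP(q) = (1/g) * sum over positives p of rank_pos(p) / rank_all(p),
        # where rank_pos is the rank among positives and rank_all is the rank over the batch.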
g = batch_size // num_classes_batch
ap = torch.zeros(batch_size).to(embeddings.device)
for i in range(num_classes_batch):
for j in range(g):
pos_rank = sims_pos_ranks[i, j]
all_rank = sims_ranks[i * g + j, i * g : (i + 1) * g]
ap[i * g + j] = torch.sum(pos_rank / all_rank) / g
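        # AP lies in (0, 1], so the per-element loss is its complement, optionally
        # reweighted by weights derived from the miner's indices_tuple.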
miner_weights = lmu.convert_to_weights(indices_tuple, labels, dtype=ap.dtype)
loss = (1 - ap) * miner_weights
return {
"ap_loss": {
"losses": loss,
"indices": c_f.torch_arange_from_size(loss),
"reduction_type": "element",
}
}