import torch
import torch.nn.functional as F

from ..distances import CosineSimilarity
from ..utils import common_functions as c_f
from ..utils import loss_and_miner_utils as lmu
from .base_metric_loss_function import BaseMetricLossFunction


class SmoothAPLoss(BaseMetricLossFunction):
    """
    Implementation of the SmoothAP loss: https://arxiv.org/abs/2007.12163

    Args:
        temperature: Divisor applied to pairwise similarity differences before
            the sigmoid that smooths the ranking function. Smaller values make
            the sigmoid steeper, i.e. the smoothed rank closer to a hard rank.
        **kwargs: Forwarded to ``BaseMetricLossFunction``.

    Only ``CosineSimilarity`` is accepted as the distance.
    """

    def __init__(self, temperature=0.01, **kwargs):
        super().__init__(**kwargs)
        c_f.assert_distance_type(self, CosineSimilarity)
        self.temperature = temperature

    def get_default_distance(self):
        return CosineSimilarity()

    @staticmethod
    def _check_balanced_labels(labels):
        """Raise ValueError unless every class in ``labels`` has the same number of elements."""
        # unique(return_counts=True) only reports classes that are present, and
        # unlike bincount it works for negative or very large label values.
        counts = torch.unique(labels, return_counts=True)[1]
        if counts.unique().size(0) != 1:
            raise ValueError(
                "All classes must have the same number of elements in the labels.\n"
                "The given labels have the following number of elements: {}.\n"
                "You can achieve this using the samplers.MPerClassSampler class and setting the batch_size and m.".format(
                    counts.cpu().tolist()
                )
            )

    # Implementation is based on the original repository:
    # https://github.com/Andrew-Brown1/Smooth_AP/blob/master/src/Smooth_AP_loss.py#L87
    def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
        """Compute the per-element Smooth-AP loss ``1 - AP``.

        Every element of the batch acts as a query against the whole batch.
        The elements sharing its label are its positives.

        Returns:
            The loss dictionary with one ``"ap_loss"`` entry per element
            (``reduction_type="element"``), weighted by the miner weights
            derived from ``indices_tuple``.

        Raises:
            ValueError: if the classes present in ``labels`` do not all have
                the same number of elements.
        """
        # All classes in the batch must have the same number of elements,
        # e.g. as produced by samplers.MPerClassSampler:
        #   valid:     [ A,A,A, B,B,B, C,C,C ]   or   [ A,A, B,B, C,C ]
        #   NOT valid: [ B,B,B, A,A,A,A, C,C,C ]
        # Positives are found by comparing labels, so the elements of a class
        # need not be contiguous and classes may appear in any order.
        c_f.labels_required(labels)
        c_f.ref_not_supported(embeddings, labels, ref_emb, ref_labels)
        self._check_balanced_labels(labels)

        batch_size = embeddings.size(0)
        sims = self.distance(embeddings)  # (B, B) similarity matrix
        device, dtype = sims.device, sims.dtype

        # sims_diff[q, k, j] = sims[q, j] - sims[q, k]: how much element j
        # outscores element k for query q. Shape (B, B, B), so memory is O(B^3).
        sims_diff = sims.unsqueeze(1) - sims.unsqueeze(2)
        # An element is never compared with itself (j == k).
        not_self = 1.0 - torch.eye(batch_size, device=device, dtype=dtype)
        sims_sigm = torch.sigmoid(sims_diff / self.temperature) * not_self

        # Smoothed rank of element k among all elements, for query q: (B, B).
        all_ranks = sims_sigm.sum(dim=-1) + 1

        # pos_mask[q, j] = 1 when j shares q's label. This includes j == q,
        # which matches the reference implementation (the query is part of its
        # own positive group).
        labels = labels.to(device)
        pos_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)).to(dtype)

        # Smoothed rank of element k among the positives only, for query q.
        pos_ranks = (sims_sigm * pos_mask.unsqueeze(1)).sum(dim=-1) + 1

        # Smooth-AP per query: mean over positives k of pos_rank / all_rank.
        # pos_mask.sum(-1) >= 1 because of the diagonal, so no division by zero.
        ap = (pos_mask * pos_ranks / all_ranks).sum(dim=-1) / pos_mask.sum(dim=-1)

        miner_weights = lmu.convert_to_weights(indices_tuple, labels, dtype=ap.dtype)
        loss = (1 - ap) * miner_weights
        return {
            "ap_loss": {
                "losses": loss,
                "indices": c_f.torch_arange_from_size(loss),
                "reduction_type": "element",
            }
        }
Memory