import numpy as np
import torch

from ..distances import LpDistance
from ..utils import common_functions as c_f
from ..utils import loss_and_miner_utils as lmu
from .base_metric_loss_function import BaseMetricLossFunction


class AngularLoss(BaseMetricLossFunction):
    """
    Implementation of https://arxiv.org/abs/1708.01682

    Args:
        alpha: The angle (as described in the paper), specified in degrees.
    """

    def __init__(self, alpha=40, **kwargs):
        super().__init__(**kwargs)
        # The angular loss is derived for L2-normalized embeddings under
        # Euclidean distance, so enforce that distance configuration.
        c_f.assert_distance_type(
            self, LpDistance, p=2, power=1, normalize_embeddings=True
        )
        self.alpha = torch.tensor(np.radians(alpha))
        self.add_to_recordable_attributes(list_of_names=["alpha"], is_stat=False)
        self.add_to_recordable_attributes(
            list_of_names=["average_angle"], is_stat=True
        )

    def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
        c_f.labels_required(labels)
        anchors, positives, keep_mask, anchor_idx, positive_idx = self.get_pairs(
            embeddings, labels, indices_tuple, ref_emb, ref_labels
        )
        if anchors is None:
            return self.zero_losses()
        sq_tan_alpha = torch.tan(self.alpha) ** 2
        # x_a . x_p for each anchor-positive pair: shape (num_pairs, 1)
        ap_dot = torch.sum(anchors * positives, dim=1, keepdim=True)
        # (x_a + x_p) . x_n for every candidate negative:
        # shape (num_pairs, num_ref) after squeeze and transpose
        ap_matmul_embeddings = torch.matmul(
            (anchors + positives), (ref_emb.unsqueeze(2))
        )
        ap_matmul_embeddings = ap_matmul_embeddings.squeeze(2).t()
        # The angular loss in its dot-product form:
        #   f_apn = 4 tan^2(alpha) (x_a + x_p) . x_n
        #           - 2 (1 + tan^2(alpha)) (x_a . x_p)
        final_form = (4 * sq_tan_alpha * ap_matmul_embeddings) - (
            2 * (1 + sq_tan_alpha) * ap_dot
        )
        # Per-pair loss is log(1 + sum_n exp(f_apn)); add_one supplies the +1,
        # and keep_mask restricts the sum to true negatives.
        losses = lmu.logsumexp(final_form, keep_mask=keep_mask, add_one=True)
        return {
            "loss": {
                "losses": losses,
                "indices": (anchor_idx, positive_idx),
                "reduction_type": "pos_pair",
            }
        }

    def get_pairs(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
        a1, p, a2, _ = lmu.convert_to_pairs(indices_tuple, labels, ref_labels)
        if len(a1) == 0 or len(a2) == 0:
            return [None] * 5
        anchors = self.distance.normalize(embeddings[a1])
        positives = self.distance.normalize(ref_emb[p])
        # keep_mask[i, j] is True when ref_emb[j] is a valid negative for
        # pair i, i.e. its label differs from the anchor's label.
        keep_mask = labels[a1].unsqueeze(1) != ref_labels.unsqueeze(0)
        self.set_stats(anchors, positives, embeddings, ref_emb, keep_mask)
        return anchors, positives, keep_mask, a1, p

    def set_stats(self, anchors, positives, embeddings, ref_emb, keep_mask):
        if self.collect_stats:
            with torch.no_grad():
                # Record the average angle at the negative vertex of the
                # triangle formed by the anchor-positive midpoint (center)
                # and each negative.
                centers = (anchors + positives) / 2
                ap_dist = self.distance.pairwise_distance(anchors, positives)
                nc_dist = self.distance.get_norm(
                    centers - ref_emb.unsqueeze(1), dim=2
                ).t()
                angles = torch.atan(ap_dist.unsqueeze(1) / (2 * nc_dist))
                average_angle = torch.sum(angles[keep_mask]) / torch.sum(keep_mask)
                self.average_angle = np.degrees(average_angle.item())
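

# A minimal usage sketch (assuming the standard pytorch-metric-learning API,
# where a loss object is called directly on a batch of embeddings and labels;
# the batch and embedding sizes below are arbitrary examples):
#
#     import torch
#     from pytorch_metric_learning.losses import AngularLoss
#
#     loss_func = AngularLoss(alpha=40)  # alpha in degrees, per the docstring
#     embeddings = torch.randn(32, 128, requires_grad=True)
#     labels = torch.randint(0, 10, (32,))  # one integer class label per row
#     loss = loss_func(embeddings, labels)
#     loss.backward()
#
# Embeddings need not be pre-normalized: get_pairs normalizes anchors and
# positives via self.distance before the loss is computed.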