import torch

from .large_margin_softmax_loss import LargeMarginSoftmaxLoss


class SphereFaceLoss(LargeMarginSoftmaxLoss):
    """SphereFace loss, as described in https://arxiv.org/pdf/1704.08063.pdf.

    Everything except logit scaling is inherited unchanged from
    ``LargeMarginSoftmaxLoss``; this subclass only overrides
    ``scale_logits`` so that each sample's logits are weighted by the L2
    norm of that sample's embedding, in addition to ``self.scale``.
    """

    def scale_logits(self, logits, embeddings):
        """Scale logits by the per-sample embedding norm and ``self.scale``.

        Args:
            logits: tensor of logits; assumed to be (batch, num_classes)
                with rows aligned to ``embeddings`` -- TODO confirm against
                the parent class.
            embeddings: tensor whose L2 norm is taken over dim 1
                (presumably (batch, embedding_size)).

        Returns:
            ``logits * ||embeddings||_2 * self.scale``, where the norm keeps
            a size-1 axis at dim 1 so it broadcasts across each row of
            ``logits``.
        """
        # keepdim=True leaves a size-1 axis at dim 1 -- the same shape as
        # reducing dim 1 and then unsqueezing it back in.
        per_sample_norm = torch.norm(embeddings, p=2, dim=1, keepdim=True)
        # Apply the norm first, then the global scale (self.scale is
        # presumably provided by the parent class -- not visible here).
        weighted = logits * per_sample_norm
        return weighted * self.scale
Memory