import torch
from .large_margin_softmax_loss import LargeMarginSoftmaxLoss
class SphereFaceLoss(LargeMarginSoftmaxLoss):
# implementation of https://arxiv.org/pdf/1704.08063.pdf
def scale_logits(self, logits, embeddings):
embedding_norms = torch.norm(embeddings, p=2, dim=1)
return logits * embedding_norms.unsqueeze(1) * self.scale