import torch

from ..utils import loss_and_miner_utils as lmu
from .base_distance import BaseDistance


class LpDistance(BaseDistance):
    """Pairwise Lp distance between embeddings, using ``p = self.p``.

    ``self.p`` and ``self.is_inverted`` are provided by ``BaseDistance``
    (presumably configured through the constructor kwargs -- confirm against
    ``BaseDistance``).
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments to ``BaseDistance``.

        Raises:
            AssertionError: if the configured instance has ``is_inverted`` set;
                this class only supports the non-inverted configuration.
        """
        super().__init__(**kwargs)
        # NOTE(review): presumably "inverted" means larger = more similar,
        # which does not fit a raw Lp distance -- confirm in BaseDistance.
        assert not self.is_inverted

    def compute_mat(self, query_emb, ref_emb):
        """Return the matrix of Lp distances between every query and reference row.

        Args:
            query_emb: query embeddings; assumed to be (N, D) -- TODO confirm
                callers never pass higher-rank input.
            ref_emb: reference embeddings, assumed (M, D). If None,
                ``query_emb`` is compared against itself.

        Returns:
            A tensor of shape (N, M) with the same dtype and device as
            ``query_emb``. Entry (i, j) is ``||query_emb[i] - ref_emb[j]||_p``.
        """
        if ref_emb is None:
            ref_emb = query_emb
        if query_emb.dtype == torch.float16:
            # cdist doesn't work for float16, so compute in float32 and cast
            # back. Gathering every (query, ref) pair instead would allocate
            # N*M*D-sized tensors; this path needs only O(N*M + (N+M)*D) extra
            # memory. It also avoids the eps that F.pairwise_distance adds, so
            # results agree with the non-float16 branch below.
            dist = torch.cdist(query_emb.float(), ref_emb.float(), p=self.p)
            return dist.to(query_emb.dtype)
        return torch.cdist(query_emb, ref_emb, p=self.p)

    def pairwise_distance(self, query_emb, ref_emb):
        """Return the Lp distance between corresponding rows of the two inputs.

        Delegates to ``torch.nn.functional.pairwise_distance``, which adds a
        small ``eps`` (default 1e-6) to the difference before taking the norm.
        """
        return torch.nn.functional.pairwise_distance(query_emb, ref_emb, p=self.p)
Memory