import torch


class BatchedDistance(torch.nn.Module):
    """Wraps a distance function so the pairwise matrix is computed in batches
    of query rows, instead of materializing the full matrix at once. Each block
    is handed to iter_fn(mat, s, e) rather than returned."""

    def __init__(self, distance, iter_fn=None, batch_size=32):
        super().__init__()
        self.distance = distance
        self.iter_fn = iter_fn
        self.batch_size = batch_size

    def forward(self, query_emb, ref_emb=None):
        # Compare queries against themselves if no reference embeddings are given.
        ref_emb = ref_emb if ref_emb is not None else query_emb
        n = query_emb.shape[0]
        for s in range(0, n, self.batch_size):
            e = s + self.batch_size
            L = query_emb[s:e]
            # Distance block between this slice of queries and all references.
            mat = self.distance(L, ref_emb)
            self.iter_fn(mat, s, e)

    def __getattr__(self, name):
        # Fall back to the wrapped distance for attributes not found on this module.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.distance, name)
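A minimal usage sketch: any callable that maps (query_emb, ref_emb) to a pairwise matrix can be wrapped. The names pairwise_euclidean and print_block below are illustrative helpers for this example, not part of the class.

import torch

# Illustrative wrapped distance: any callable taking (query_emb, ref_emb) and
# returning a pairwise matrix works; torch.cdist is used here as a stand-in.
def pairwise_euclidean(query_emb, ref_emb):
    return torch.cdist(query_emb, ref_emb)

# Callback invoked once per batch with the distance block and the query row range.
def print_block(mat, s, e):
    print(f"queries {s}:{e} -> block of shape {tuple(mat.shape)}")

embeddings = torch.randn(100, 16)
distance = BatchedDistance(pairwise_euclidean, print_block, batch_size=32)
distance(embeddings)  # returns nothing; results are consumed through the callback

Because forward never builds the full matrix, peak memory scales with batch_size x num_references rather than num_queries x num_references.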
Memory