如何有效地在PyTorch中计算批次成对距离

时间:2019-03-12 16:11:43

标签: deep-learning pytorch tensor

我有形状为BxNxD的张量X和形状为BxNxD的张量。

我想计算批次中每个元素的成对距离,即我是BxMxN张量。

我该怎么做?

https://github.com/pytorch/pytorch/issues/9406在此主题上有一些讨论,但我不理解,因为有许多实现细节,而没有强调实际解决方案。

一种幼稚的方法是将答案用于非批处理的成对距离,如下所述:https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065,即

import torch
import numpy as np

B = 32
N = 128
M = 256
D = 3

X = torch.from_numpy(np.random.normal(size=(B, N, D)))
Y = torch.from_numpy(np.random.normal(size=(B, M, D)))


def pairwise_distances(x, y=None):
    x_norm = (x**2).sum(1).view(-1, 1)
    if y is not None:
        y_t = torch.transpose(y, 0, 1)
        y_norm = (y**2).sum(1).view(1, -1)
    else:
        y_t = torch.transpose(x, 0, 1)
        y_norm = x_norm.view(1, -1)

    dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
    return torch.clamp(dist, 0.0, np.inf)


out = []
for b in range(B):
    out.append(pairwise_distances(X[b], Y[b]))
print(torch.stack(out).shape)

如何在不循环B的情况下执行此操作? 谢谢

1 个答案:

答案 0 :(得分:2)

我遇到了类似的问题,并花了一些时间来找到最简单,最快的解决方案。现在,您可以使用PyTorch cdist计算批量距离,这将为您提供<div class="className variable1variable2"><div> 张量:

BxMxN

此外,如果您只想计算两个矩阵的每一对行之间的距离,它也可以很好地工作。