如何向量化以下python代码

时间:2018-11-23 07:08:15

标签: numpy matrix vectorization matrix-multiplication pytorch

我试图获得一个矩阵,其中每个元素的计算如下:

X = torch.ones(batch_size, dim)
X_ = torch.ones(batch_size, dim)
Y = torch.ones(batch_size, dim)
M = torch.zeros(batch_size, batch_size)
for i in range(batch_size):
    for j in range(batch_size):
        M[i, j] = ((X[i] - X_[i] * Y[j])**2).sum()

按元素计算M的速度非常慢,是否有关于如何使用矩阵乘法替换for循环的建议?

谢谢。

1 个答案:

答案 0 :(得分:2)

如果您想sum()暗淡,可以将2D问题“提升”到3D并求和:

M = ((X[:, None, :] - X_[:, None, :] * Y[None, ...])**2).sum(dim=2)

工作原理

X[:, None, :]X_[:, None, :]的大小为(batch_size, 1, dim)的3D,而Y[None, ...]的大小为(1, batch_size, dim)

X_[:, None, :] * Y[None, ...] pytorch broadcasts的尺寸乘以适当的尺寸以得到尺寸(batch_size, batch_size, dim)的结果。
最后,您sum()仅在最后一个维度(dim=2)上才能获得大小为M的输出(batch_size, batch_size)

这里的诀窍是利用broadcasting来完成的。