我试图获得一个矩阵,其中每个元素的计算如下:
X = torch.ones(batch_size, dim)
X_ = torch.ones(batch_size, dim)
Y = torch.ones(batch_size, dim)
M = torch.zeros(batch_size, batch_size)
for i in range(batch_size):
for j in range(batch_size):
M[i, j] = ((X[i] - X_[i] * Y[j])**2).sum()
按元素计算M
的速度非常慢,是否有关于如何使用矩阵乘法替换for循环的建议?
谢谢。
答案 0 :(得分:2)
如果您想sum()
暗淡,可以将2D问题“提升”到3D并求和:
M = ((X[:, None, :] - X_[:, None, :] * Y[None, ...])**2).sum(dim=2)
工作原理:
X[:, None, :]
和X_[:, None, :]
的大小为(batch_size, 1, dim)
的3D,而Y[None, ...]
的大小为(1, batch_size, dim)
。
将X_[:, None, :] * Y[None, ...]
pytorch broadcasts的尺寸乘以适当的尺寸以得到尺寸(batch_size, batch_size, dim)
的结果。
最后,您sum()
仅在最后一个维度(dim=2)
上才能获得大小为M
的输出(batch_size, batch_size)
。
这里的诀窍是利用broadcasting来完成的。