向量化Pytorch张量索引操作

时间:2020-08-17 14:21:09

标签: python numpy pytorch vectorization

我正在尝试对PyTorch中的操作进行矢量化处理,但是我不确定该怎么做。这是现在使用for循环的代码。 'm'是具有int键和1d张量作为值的字典。 输出掩码为2d。 L是层数,此循环可能是必需的。因此,我希望主要替换2个内部循环。我当时想以某种方式使用torch.gather,但没有成功

for l in range(L):
    mask = torch.zeros((m[l].shape[0], m[l-1].shape[0]))
    for i in range(m[l].shape[0]):
        for j in range(m[l-1].shape[0]):
            mask[i,j] = R[m[l-1][j], m[l][i]]
    masks.append(mask)

我将不胜感激!预先感谢。

1 个答案:

答案 0 :(得分:0)

我想我自己找到了答案。如此处所述,您可以使用numpy高级索引:https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html

然后将其归结为以下内容

for l in range(L):
    mask = R[m[l-1], m[l][:, np.newaxis]]
    masks.append(mask)

np.newaxis确保对每一行重复列索引。