PyTorch:使用行索引的2D张量索引2D张量

时间:2019-04-25 21:18:56

标签: python pytorch

我有一个形状为a的火炬张量(x, n)和另一个形状为b的张量(y, n),其中y <= xb的每一列都包含a的行索引序列,而我想做的就是以某种方式用a索引b以便获得形状为(y, n)的张量,其中第i列包含a[:, i][b[:, i]](不确定这是否是表达它的正确方法)。

以下是一个示例(其中x = 5,y = 3和n = 4):

import torch

a = torch.Tensor(
    [[0.1, 0.2, 0.3, 0.4],
     [0.6, 0.7, 0.8, 0.9],
     [1.1, 1.2, 1.3, 1.4],
     [1.6, 1.7, 1.8, 1.9],
     [2.1, 2.2, 2.3, 2.4]]
)

b = torch.LongTensor(
    [[0, 3, 1, 2],
     [2, 2, 2, 0],
     [1, 1, 0, 4]]
)

# How do I get from a and b to c
# (so that I can also assign to those elements in a)?

c = torch.Tensor(
    [[0.1, 1.7, 0.8, 1.4],
     [1.1, 1.2, 1.3, 0.4],
     [0.6, 0.7, 0.3, 2.4]]
)

我无法解决这个问题。我要寻找的是一种既不会产生张量c,又让我将与c相同形状的张量分配给{{1 }}由。

1 个答案:

答案 0 :(得分:0)

我尝试使用index_select,但它仅支持1-dim数组作为索引。

bt = b.transpose(0, 1)
at = a.transpose(0, 1)
ct = [torch.index_select(at[i], dim=0, index=bt[i]) for i in range(len(at))]
c  = torch.stack(ct).transpose(0, 1)
print(c)
"""
tensor([[0.1000, 1.7000, 0.8000, 1.4000],
        [1.1000, 1.2000, 1.3000, 0.4000],
        [0.6000, 0.7000, 0.3000, 2.4000]])
"""

这可能不是最佳解决方案,但希望这至少对您有帮助。