使用2d张量索引3d张量

时间:2019-04-11 08:30:31

标签: pytorch

我有一个形状为source的3d张量(bsz x slen1 x nhd)和一个形状为index的2d张量(bsz x slen2)。更具体地说,我有:

source = 32 x 20 x 768
index  = 32 x 16

index张量中的每个值都在[0, 19]之间,source是根据32 x 16 x 768张量的第二次变暗的所需矢量的索引。

索引后,我期望输出形状为bsz, _, nhid = source.size() _, slen = index.size() source = source.reshape(-1, nhid) source = source[index.reshape(-1), :] source = source.reshape(bsz, slen, nhid) 的张量。

目前我正在这样做:

source = torch.FloatTensor([
    [[ 0.2413, -0.6667,  0.2621],
     [-0.4216,  0.3722, -1.2258],
     [-0.2436, -1.5746, -0.1270],
     [ 1.6962, -1.3637,  0.8820],
     [ 0.3490, -0.0198,  0.7928]],

    [[-0.0973,  2.3106, -1.8358],
     [-1.9674,  0.5381,  0.2406],
     [ 3.0731,  0.3826, -0.7279],
     [-0.6262,  0.3478, -0.5112],
     [-0.4147, -1.8988, -0.0092]]
     ])

index = torch.LongTensor([[0, 1, 2, 3], 
                          [1, 2, 3, 4]])

因此,我将3d源张量转换为2d张量,并将2d索引张量转换为1d张量,然后执行索引。这是正确的吗?

还有更好的方法吗?

更新

我检查了我的代码没有给出预期的结果。为了解释我想要的内容,我提供了以下代码段。

torch.FloatTensor([
    [[ 0.2413, -0.6667,  0.2621],
     [-0.4216,  0.3722, -1.2258],
     [-0.2436, -1.5746, -0.1270],
     [ 1.6962, -1.3637,  0.8820]],

    [[-1.9674,  0.5381,  0.2406],
     [ 3.0731,  0.3826, -0.7279],
     [-0.6262,  0.3478, -0.5112],
     [-0.4147, -1.8988, -0.0092]]
     ])

我希望输出张量为:

AddType image/svg+xml .svg .svgz

2 个答案:

答案 0 :(得分:2)

我已经解决了问题。因此,我实际上需要定义一个偏移量。以下代码对我有用。

index = torch.LongTensor([[0, 1, 2, 3], [1, 2, 3, 4]])
offset = torch.arange(0, source.size(0) * source.size(1), source.size(1))
index = index + offset.unsqueeze(1)

source = source.reshape(-1, source.shape[-1])[index]

答案 1 :(得分:1)

更新

source[torch.arange(source.shape[0]).unsqueeze(-1), index]

请注意,torch.arange(source.shape[0]).unsqueeze(-1)给出:

tensor([[0],
        [1]])  # 2 x 1

index是:

tensor([[0, 1, 2, 3],
        [1, 2, 3, 4]])  # 2 x 4

arange索引批次维度,而index同时索引slen1维度。 unsqueeze调用将额外的x 1维添加到arange结果中,以便可以一起广播两者。