如何将PyTorch张量与另一个张量切片?

时间:2020-04-18 13:54:38

标签: python numpy pytorch tensor

我有:

inp =  torch.randn(4, 1040, 161)

还有一个名为indices的张量,其值是:

tensor([[124, 583, 158, 529],
        [172, 631, 206, 577]], device='cuda:0')

我想要等同于:

inp0 = inp[:,124:172,:]
inp1 = inp[:,583:631,:]
inp2 = inp[:,158:206,:]
inp3 = inp[:,529:577,:]

除所有加在一起外,其.size为[4, 48, 161]。我该怎么做?

当前,我的解决方案是一个for循环:

            left_indices = torch.empty(inp.size(0), self.side_length, inp.size(2))
            for batch_index in range(len(inp)):
                print(left_indices_start[batch_index].item())
                left_indices[batch_index] = inp[batch_index, left_indices_start[batch_index].item():left_indices_end[batch_index].item()]

2 个答案:

答案 0 :(得分:2)

在这里您可以进行编辑(编辑:在执行以下操作之前,您可能需要使用tensor=tensor.cpu()将张量复制到cpu):

index = tensor([[124, 583, 158, 529],
    [172, 631, 206, 577]], device='cuda:0')
#create a concatenated list of ranges of indices you desire to slice
indexer = np.r_[tuple([np.s_[i:j] for (i,j) in zip(index[0,:],index[1,:])])]
#slice using numpy indexing
sliced_inp = inp[:, indexer, :]

这是它的工作方式:

np.s_[i:j]创建一个索引对象的切片对象(仅是一个范围),该索引对象从start = i到end = j

np.r_[i:j, k:m]创建切片(i,j)(k,m)中所有索引的列表(您可以将更多的切片传递到np.r_一次将它们全部串联起来。这是一个示例只能连接两个切片。)

因此,indexer通过连接切片列表(每个切片是索引范围)来创建ALL索引列表。

更新:如果您需要删除间隔重叠和排序间隔:

indexer = np.unique(indexer)

如果您要删除间隔重叠,但不进行排序并保持原始顺序(以及重叠的首次出现)

uni = np.unique(indexer, return_index=True)[1]
indexer = [indexer[index] for index in sorted(uni)]

答案 1 :(得分:1)

inp =  torch.randn(4, 1040, 161)   
indices = torch.tensor([[124, 583, 158, 529],
            [172, 631, 206, 577]])
k = zip(indices[0], indices[1])
for i,j in k:
    print(inp[:,i:j,:])

您可以像这样实现它... zip函数有助于将索引张量转换为元组列表,您可以直接通过for循环使用

希望它可以帮助您。...