我有:
inp = torch.randn(4, 1040, 161)
还有一个名为indices
的张量,其值是:
tensor([[124, 583, 158, 529],
[172, 631, 206, 577]], device='cuda:0')
我想要等同于:
inp0 = inp[:,124:172,:]
inp1 = inp[:,583:631,:]
inp2 = inp[:,158:206,:]
inp3 = inp[:,529:577,:]
除所有加在一起外,其.size为[4, 48, 161]
。我该怎么做?
当前,我的解决方案是一个for
循环:
left_indices = torch.empty(inp.size(0), self.side_length, inp.size(2))
for batch_index in range(len(inp)):
print(left_indices_start[batch_index].item())
left_indices[batch_index] = inp[batch_index, left_indices_start[batch_index].item():left_indices_end[batch_index].item()]
答案 0 :(得分:2)
在这里您可以进行编辑(编辑:在执行以下操作之前,您可能需要使用tensor=tensor.cpu()
将张量复制到cpu):
index = tensor([[124, 583, 158, 529],
[172, 631, 206, 577]], device='cuda:0')
#create a concatenated list of ranges of indices you desire to slice
indexer = np.r_[tuple([np.s_[i:j] for (i,j) in zip(index[0,:],index[1,:])])]
#slice using numpy indexing
sliced_inp = inp[:, indexer, :]
这是它的工作方式:
np.s_[i:j]
创建一个索引对象的切片对象(仅是一个范围),该索引对象从start = i
到end = j
。
np.r_[i:j, k:m]
创建切片(i,j)
和(k,m)
中所有索引的列表(您可以将更多的切片传递到np.r_
一次将它们全部串联起来。这是一个示例只能连接两个切片。)
因此,indexer
通过连接切片列表(每个切片是索引范围)来创建ALL索引列表。
更新:如果您需要删除间隔重叠和排序间隔:
indexer = np.unique(indexer)
如果您要删除间隔重叠,但不进行排序并保持原始顺序(以及重叠的首次出现)
uni = np.unique(indexer, return_index=True)[1]
indexer = [indexer[index] for index in sorted(uni)]
答案 1 :(得分:1)
inp = torch.randn(4, 1040, 161)
indices = torch.tensor([[124, 583, 158, 529],
[172, 631, 206, 577]])
k = zip(indices[0], indices[1])
for i,j in k:
print(inp[:,i:j,:])
您可以像这样实现它... zip函数有助于将索引张量转换为元组列表,您可以直接通过for循环使用
希望它可以帮助您。...