Pytorch:沿多轴具有张量的索引或一次分散到多个索引

时间:2020-06-05 03:13:11

标签: python numpy indexing pytorch

我正在尝试更新Pytorch中多维张量的非常特定的索引,并且我不确定如何访问正确的索引。我可以在Numpy中以非常简单的方式做到这一点:

import numpy as np
#set up the array containing the data
data = 100*np.ones((10,10,2))
data[5:,:,:] = 0
#select the data points that I want to update
idxs = np.nonzero(data.sum(2))
#generate the updates that I am going to do
updates = np.random.randint(5,size=(idxs[0].shape[0],2))
#update the data
data[idxs[0],idxs[1],:] = updates

我需要在Pytorch中实现此功能,但是我不确定如何执行此操作。似乎我需要scatter函数,但是该函数只能沿一个维度运行,而不需要多个维度。我该怎么办?

1 个答案:

答案 0 :(得分:3)

除了torch.nonzero以外,这些操作在其PyTorch副本中的工作原理完全相同,https://stackoverflow.com/a/19168875/1661745默认情况下返回大小为 [z,n] (其中 z 是非零元素的数量,而 n 是维数),而不是大小为 [z] n 张量的元组(就像NumPy一样),但是可以通过设置as_tuple=True来更改该行为。

除此之外,您可以将其直接转换为PyTorch,但是您需要确保类型匹配,因为您无法将类型为torch.long(默认为torch.randint)的张量分配给张量类型为torch.float(默认为torch.ones)。在这种情况下,data可能具有类型torch.long

#set up the array containing the data
data = 100*torch.ones((10,10,2), dtype=torch.long)
data[5:,:,:] = 0
#select the data points that I want to update
idxs = torch.nonzero(data.sum(2), as_tuple=True)
#generate the updates that I am going to do
updates = torch.randint(5,size=(idxs[0].shape[0],2))
#update the data
data[idxs[0],idxs[1],:] = updates