如何从二维pytorch张量获取最大元素的行和列索引?

时间:2020-02-26 19:55:27

标签: python indexing pytorch tensor

有什么方法可以检索二维pytorch张量中包含的最大元素的行索引和列索引?例如,请参见下面的pytorch张量a

a
>> torch.tensor([1,2,3],
                [9,5,4],
                [6,7,8])

张量a中最大的元素是9,它发生在第二行的第一列。如果将其更改为从零开始的python列和行索引,则该元素的列索引将为0,而行索引将为1。

有什么方法可以从2维pytorch张量a中检索索引[1,0]吗?

1 个答案:

答案 0 :(得分:0)

不幸的是,没有内置方法。 但是,您可以使用numpy:

np.unravel_index(torch.argmax(a), a.shape)

否则,您需要编写自己的逻辑,例如:

def unravel_index(flat_idx, shape): 
     multi_idx = [] 
     r = flat_idx 
     for s in shape[:-1]: 
         multi_idx.append(r // s) 
         r = r % s 
     multi_idx.append(r % s) 
     return multi_idx