有什么方法可以检索二维pytorch
张量中包含的最大元素的行索引和列索引?例如,请参见下面的pytorch
张量a
:
a
>> torch.tensor([1,2,3],
[9,5,4],
[6,7,8])
张量a
中最大的元素是9,它发生在第二行的第一列。如果将其更改为从零开始的python列和行索引,则该元素的列索引将为0,而行索引将为1。
有什么方法可以从2维pytorch张量a
中检索索引[1,0]吗?
答案 0 :(得分:0)
不幸的是,没有内置方法。 但是,您可以使用numpy:
np.unravel_index(torch.argmax(a), a.shape)
否则,您需要编写自己的逻辑,例如:
def unravel_index(flat_idx, shape):
multi_idx = []
r = flat_idx
for s in shape[:-1]:
multi_idx.append(r // s)
r = r % s
multi_idx.append(r % s)
return multi_idx