pytorch Argmax中发生冲突时的索引选择

时间:2019-03-13 10:35:37

标签: python pytorch tensor argmax

我一直在尝试学习张量运算,而这使我陷入了循环。
假设我有一个张量t:

    t = torch.tensor([
        [1,0,0,2],
        [0,3,3,0],
        [4,0,0,5]
    ], dtype  = torch.float32)

现在这是2级张量,我们可以为每个级/维应用argmax。 假设我们将其应用于dim = 1

t.max(dim = 1)
(tensor([2., 3., 5.]), tensor([3, 2, 3]))

现在我们可以看到结果与预期的一样,沿着dim = 1的张量具有2、3和5作为最大元素。但是在3上存在冲突。有两个完全相似的值。
如何解决?它是任意选择的吗?是否有选择顺序,例如L-R,较高的索引值?
感谢您对如何解决此问题有任何见解!

1 个答案:

答案 0 :(得分:3)

我本人几次偶然发现这是一个好问题。最简单的答案是,不能保证 setTimeout(function() { $('._my_save_button').trigger('click'); }, 3000); (或C:\Users\tushar\PycharmProjects>conda activate YourCondaEnv && tensorboard --logdir="NewTF" ,当指定dim时也返回索引)将始终返回相同的索引。相反,它将任何有效索引返回到argmax值,可能是随机的。正如this thread in the official forum所讨论的,这被认为是期望的行为。 (我知道不久前我读过另一个线程,这使它更加明确,但我找不到它。)

已经说过了,因为这种行为在我的用例中是不可接受的,所以我编写了以下函数来查找最左边和最右边的索引(请注意torch.argmax是您传入的函数对象):

torch.max(x, dim=k)

希望他们会有所帮助! (哦,如果您有一个更好的实现相同的实现,请告诉我)