Pytorch:是否有类似于torch.argmax的函数可以真正保持原始数据的大小?

时间:2018-11-02 10:07:39

标签: python pytorch tensor

例如, 代码是

input = torch.randn(3, 10)
result = torch.argmax(input, dim=0, keepdim=True)

input

tensor([[ 1.5742,  0.8183, -2.3005, -1.1650, -0.2451],
       [ 1.0553,  0.6021, -0.4938, -1.5379, -1.2054],
       [-0.1728,  0.8372, -1.9181, -0.9110,  0.2422]])

result

tensor([[ 0,  2,  1,  2,  2]])

但是,我想要这样的结果

tensor([[ 1,  0,  0,  0,  0],
        [ 0,  0,  1,  0,  0],
        [ 0,  1,  0,  1,  1]])

2 个答案:

答案 0 :(得分:1)

最后,我解决了。但是此解决方案可能无效。 代码如下,

input = torch.randn(3, 10)
result = torch.argmax(input, dim=0, keepdim=True)
result_0 = result == 0
result_1 = result == 1
result_2 = result == 2
result = torch.cat((result_0, result_1, result_2), 0)

答案 1 :(得分:0)

Numpy 有函数 ';alert(String.fromCharCode(88,83,83))//\';alert(String.fromCharCode(88,83,83))//";alert(String.fromCharCode(88,83,83))//\";alert(String.fromCharCode(88,83,83))//--></SCRIPT>">'><SCRIPT>alert(String.fromCharCode(88,83,83))</SCRIPT>=&{} 可以回答这个问题,我还没有找到这个函数的等效火炬。

put_along_axis