在pytorch的官方网页上,我看到了以下代码和答案:
>> a = torch.randn(4, 4)
>> a
0.0692 0.3142 1.2513 -0.5428
0.9288 0.8552 -0.2073 0.6409
1.0695 -0.0101 -2.4507 -1.2230
0.7426 -0.7666 0.4862 -0.6628
torch.FloatTensor of size 4x4]
>>> torch.max(a, 1)
(
1.2513
0.9288
1.0695
0.7426
[torch.FloatTensor of size 4]
,
2
0
0
0
[torch.LongTensor of size 4]
)
我知道第一个结果对应于每行的最大数量,但我没有得到第二个张量(LongTensor)
我尝试了其他随机示例,在pytorch.max之后,我找到了这些结果
0.9477 1.0090 0.8348 -1.3513
-0.4861 1.2581 0.3972 1.5751
-1.2277 -0.6201 -1.0553 0.6069
0.1688 0.1373 0.6544 -0.7784
[torch.FloatTensor of size 4x4]
(
1.0090
1.5751
0.6069
0.6544
[torch.FloatTensor of size 4]
,
1
3
3
2
[torch.LongTensor of size 4]
)
有谁能告诉我这些LongTensor数据到底意味着什么?我认为这是张量之间的一个奇怪的演员,但是在一个简单的浮动张量投射后,我看到它只是削减小数
由于
答案 0 :(得分:1)
它只是沿着查询的维度告诉原始张量中max元素的 index 。
E.g。
0.9477 1.0090 0.8348 -1.3513
-0.4861 1.2581 0.3972 1.5751
-1.2277 -0.6201 -1.0553 0.6069
0.1688 0.1373 0.6544 -0.7784
[torch.FloatTensor of size 4x4]
# torch.max(a, 1)
(
1.0090
1.5751
0.6069
0.6544
[torch.FloatTensor of size 4]
,
1
3
3
2
[torch.LongTensor of size 4]
)
在torch.LongTensor
中的上述示例中,
1
是原始张量中1.0090
的索引(torch.FloatTensor)
3
是原始张量中1.5751
的索引(torch.FloatTensor)
3
是原始张量中0.6069
的索引(torch.FloatTensor)
2
是原始张量中0.6544
的索引(torch.FloatTensor)
沿维度1 。
相反,如果您已请求torch.max(a, 0)
,则torch.LongTensor
中的条目将对应于原始张量中沿维度0 的最大元素的indices
>