torch.addmm收到了无效的参数组合

时间:2017-11-23 18:33:19

标签: python pytorch tensor

在pytorch的官方网页上,我看到了以下代码和答案:

>> a = torch.randn(4, 4)
>> a

0.0692  0.3142  1.2513 -0.5428
0.9288  0.8552 -0.2073  0.6409
1.0695 -0.0101 -2.4507 -1.2230
0.7426 -0.7666  0.4862 -0.6628
torch.FloatTensor of size 4x4]

>>> torch.max(a, 1)
(
 1.2513
 0.9288
 1.0695
 0.7426
[torch.FloatTensor of size 4]
,
 2
 0
 0
 0
[torch.LongTensor of size 4]
)

我知道第一个结果对应于每行的最大数量,但我没有得到第二个张量(LongTensor)

我尝试了其他随机示例,在pytorch.max之后,我找到了这些结果

0.9477  1.0090  0.8348 -1.3513
-0.4861  1.2581  0.3972  1.5751
-1.2277 -0.6201 -1.0553  0.6069
 0.1688  0.1373  0.6544 -0.7784
[torch.FloatTensor of size 4x4]

(
 1.0090
 1.5751
 0.6069
 0.6544
[torch.FloatTensor of size 4]
, 
 1
 3
 3
 2
[torch.LongTensor of size 4]
)

有谁能告诉我这些LongTensor数据到底意味着什么?我认为这是张量之间的一个奇怪的演员,但是在一个简单的浮动张量投射后,我看到它只是削减小数

由于

1 个答案:

答案 0 :(得分:1)

它只是沿着查询的维度告诉原始张量中max元素的 index

E.g。

0.9477  1.0090  0.8348 -1.3513
-0.4861  1.2581  0.3972  1.5751
-1.2277 -0.6201 -1.0553  0.6069
 0.1688  0.1373  0.6544 -0.7784
[torch.FloatTensor of size 4x4]

# torch.max(a, 1)
(
 1.0090
 1.5751
 0.6069
 0.6544
[torch.FloatTensor of size 4]
, 
 1
 3
 3
 2
[torch.LongTensor of size 4]
)

torch.LongTensor中的上述示例中,

1是原始张量中1.0090的索引(torch.FloatTensor)
3是原始张量中1.5751的索引(torch.FloatTensor)
3是原始张量中0.6069的索引(torch.FloatTensor)
2是原始张量中0.6544的索引(torch.FloatTensor)

沿维度1

相反,如果您已请求torch.max(a, 0),则torch.LongTensor中的条目将对应于原始张量中沿维度0 的最大元素的indices >