在pytorch中,为什么sum(tensor)有时返回错误结果而不是torch.sum(tensor)?

时间:2019-03-05 17:40:22

标签: python pytorch

调试时遇到一个奇怪的问题。

>>> cos
>>>
tensor([ 0.3869,  0.2857,  0.4931,  0.5086,  0.6757,  0.6417,  0.3773,  0.4084,
         0.2496,  0.7558,  0.4305,  0.3839,  0.1892,  0.8675,  0.3392,  0.5415,
         0.4421,  0.2782,  0.5187,  0.2672,  0.2896,  0.5031,  0.4791,  0.3528,
         0.2577,  0.3932,  0.2554,  0.4925,  0.8496,  0.1264,  0.5594,  0.8667,
         ···
         0.4919,  0.4073,  0.6890,  0.3976,  0.5691,  0.0741,  0.6420,  0.4249,
         0.2785])

我想计算多少元素大于零,所以我尝试:

>>> sum(cos>0)
>>> tensor(181, dtype=torch.uint8)  # fault answer
>>> torch.sum(cos>0)
>>> tensor(437)                     # correct answer

我也重复了上述操作,但是两个答案都正确了。

>>> sum(cos[:100]>0)
>>> tensor(100, dtype=torch.uint8)
>>> torch.sum(cos[:100]>0)
>>> tensor(100)

我在pytorch 1.0.2中运行它。

0 个答案:

没有答案