调试时遇到一个奇怪的问题。
>>> cos
>>>
tensor([ 0.3869, 0.2857, 0.4931, 0.5086, 0.6757, 0.6417, 0.3773, 0.4084,
0.2496, 0.7558, 0.4305, 0.3839, 0.1892, 0.8675, 0.3392, 0.5415,
0.4421, 0.2782, 0.5187, 0.2672, 0.2896, 0.5031, 0.4791, 0.3528,
0.2577, 0.3932, 0.2554, 0.4925, 0.8496, 0.1264, 0.5594, 0.8667,
···
0.4919, 0.4073, 0.6890, 0.3976, 0.5691, 0.0741, 0.6420, 0.4249,
0.2785])
我想计算多少元素大于零,所以我尝试:
>>> sum(cos>0)
>>> tensor(181, dtype=torch.uint8) # fault answer
>>> torch.sum(cos>0)
>>> tensor(437) # correct answer
我也重复了上述操作,但是两个答案都正确了。
>>> sum(cos[:100]>0)
>>> tensor(100, dtype=torch.uint8)
>>> torch.sum(cos[:100]>0)
>>> tensor(100)
我在pytorch 1.0.2中运行它。