AttributeError:'torch.FloatTensor'对象没有属性'item'

时间:2018-04-29 12:30:59

标签: python-3.x pytorch

以下是代码:

    from __future__ import print_function
    from itertools import count

    import torch 
    import torch.autograd
    import torch.nn.functional as F

    POLY_DEGREE = 4
    W_target = torch.randn(POLY_DEGREE, 1) * 5
    b_target = torch.randn(1) * 5


    def make_features(x):
        x = x.unsqueeze(1)
        return torch.cat([x ** i for i in range(1, POLY_DEGREE+1)], 1)


    def f(x):
        return x.mm(W_target) + b_target.item()

这导致以下错误消息:

AttributeError: 'torch.FloatTensor' object has no attribute 'item'

我该如何解决这个问题?

1 个答案:

答案 0 :(得分:1)

函数item()是PyTorch 0.4.0的新功能。使用早期版本的PyTorch时,您将收到此错误。 所以你可以升级你的PyTorch版本来解决这个问题。

编辑:

我再次通过你的例子。您希望使用item()进行归档? 在你的情况下,item()应该只给出张量中的(python)浮点值。 你为什么要用这个?你可以省略item()

所以:

def f(x):
    return x.mm(W_target) + b_target

而不是:

def f(x):
    return x.mm(W_target) + b_target.item()

这应该对你有用,在PyTorch 0.4.0中没有区别。省略item()

也更有效率