torch.pow不起作用

时间:2017-11-24 09:30:50

标签: numpy pytorch

我正在尝试使用PyTorch创建自定义丢失函数,并且遇到了一个简单的错误。

当我尝试使用torch.pow取得PyTorch变量的指数时,我收到以下错误消息:

属性错误:' torch.LongTensor'对象没有属性' pow'

在python终端中,我创建了一个简单的变量,并试图做同样的事情,并收到同样的错误。这是一个应该重新创建问题的片段:

a <- dat_KOR14_16 %>% 
            group_by(MCLASSNAME) %>%
            summarise(new_avg_price = mean(AVG_PRICE))

我无法找到有关此问题的任何信息,搜索结果中也没有显示任何信息。帮助

编辑:当我尝试使用torch.sqrt()时也会出现此问题 编辑:如果我尝试做同样的问题也会发生

import torch
from torch.autograd import Variable
import numpy as np

v = Variable(torch.from_numpy(np.array([1, 2, 3, 4])))
torch.pow(v, 2)

pow绝对是v的一种方法,而且文档明确指出pow是一种存在的方法,并且作为其参数采用张量。我真的不知道这是怎么回事,在我看来,文档只是错误的,这些方法实际上并不起作用。

1 个答案:

答案 0 :(得分:2)

你需要将张量初始化为浮点数,因为pow总是返回一个浮点数。

import torch
from torch.autograd import Variable
import numpy as np

v = Variable(torch.from_numpy(np.array([1, 2, 3, 4], dtype="float32")))
torch.pow(v, 2)

之后你可以把它强制转换成整数

torch.pow(v, 2).type(torch.LongTensor)

产量

Variable containing:
  1
  4
  9
 16
[torch.LongTensor of size 4]