TypeError:需要一个float

时间:2017-08-03 16:46:25

标签: python math pytorch

我有一个pytorch变量:

preds[4,4]
Out[305]: 
Variable containing:
-96.7809
[torch.cuda.FloatTensor of size 1 (GPU 0)]

我想做以下事情:

 import math
 x=preds[4,4]
 y=maths.exp(x)
 z= y / (y+1)

然而,当我做:

y=maths.exp(x)

我收到以下错误:

   math.exp(preds[4,4])
TypeError: a float is required

如何将火炬变量转换为浮点数以便能够执行这些操作?

谢谢

1 个答案:

答案 0 :(得分:2)

索引Variable对象不会将其转换为标量。它仍然是Variable个对象。但是为numpy array建立索引的确如此。因此,将Variable对象转换为numpy然后以您希望的方式编制索引应该可以解决问题。

但是将Variable转换为numpy时会有一些小的陷阱。

如果preds是存储在cpu内存中的Variable,则可以执行此操作。

nparr = preds.data.numpy()
x = nparr[4, 4] 

但是,如果preds在gpu内存中,则必须首先将Variable转换为cpu内存,然后再将其转换为numpy对象,如下所示:

preds = preds.cpu()

然后按照上面的步骤进行操作。

nparr = preds.data.numpy()
x = nparr[4, 4]

在这两种情况下x都是标量(在你的情况下是浮点数),你可以在你选择的任何数学运算中使用它。

编辑:

是的,@ mememex是对的,您也可以直接索引tensor中包含的Variable来提取任何给定索引的标量值。

像这样:

 x = preds.data[4, 4]