我有一个pytorch变量:
preds[4,4]
Out[305]:
Variable containing:
-96.7809
[torch.cuda.FloatTensor of size 1 (GPU 0)]
我想做以下事情:
import math
x=preds[4,4]
y=maths.exp(x)
z= y / (y+1)
然而,当我做:
y=maths.exp(x)
我收到以下错误:
math.exp(preds[4,4])
TypeError: a float is required
如何将火炬变量转换为浮点数以便能够执行这些操作?
谢谢
答案 0 :(得分:2)
索引Variable
对象不会将其转换为标量。它仍然是Variable
个对象。但是为numpy array
建立索引的确如此。因此,将Variable
对象转换为numpy然后以您希望的方式编制索引应该可以解决问题。
但是将Variable
转换为numpy时会有一些小的陷阱。
如果preds
是存储在cpu内存中的Variable
,则可以执行此操作。
nparr = preds.data.numpy()
x = nparr[4, 4]
但是,如果preds
在gpu内存中,则必须首先将Variable
转换为cpu内存,然后再将其转换为numpy对象,如下所示:
preds = preds.cpu()
然后按照上面的步骤进行操作。
nparr = preds.data.numpy()
x = nparr[4, 4]
在这两种情况下x
都是标量(在你的情况下是浮点数),你可以在你选择的任何数学运算中使用它。
是的,@ mememex是对的,您也可以直接索引tensor
中包含的Variable
来提取任何给定索引的标量值。
像这样:
x = preds.data[4, 4]