RuntimeError:长时间在CPUType上不支持_th_normal

时间:2020-04-05 10:54:56

标签: python-3.x pytorch normal-distribution

我正在尝试使用以下方法从正态分布中生成数字:

from torch.distributions import Normal
noise = Normal(th.tensor([0]), th.tensor(3.20))
noise = noise.sample()

但是我收到此错误:RuntimeError: _th_normal not supported on CPUType for Long

1 个答案:

答案 0 :(得分:2)

您的第一个张量th.tensor([0]) torch.Long类型,这是由于根据传递的值会自动进行类型推断,而函数需要floatFloatTensor

您可以通过显式传递0.0来解决此问题:

import torch

noise = torch.distributions.Normal(torch.tensor([0.0]), torch.tensor(3.20))
noise = noise.sample()

更好的是,完全丢弃torch.tensor,在这种情况下,如果可能,Python类型将自动转换为float,因此这也是有效的:

import torch

noise = torch.distributions.Normal(0, 3.20)
noise = noise.sample()

请不要将torch别名为th,这不是官方名称,请使用完全限定的名称,因为这只会使所有人感到困惑。