Pytorch:为什么 torch.where 方法不像 numpy.where 那样工作?

时间:2021-02-01 06:39:39

标签: pytorch

为了使用 Numpy 在随机向量中用某个数字替换正值,用另一个数字替换负值,可以执行以下操作:

npy_p = np.random.randn(4,6)
quant = np.where(npy_p>0, c_plus , np.where(npy_p<0, c_minus , npy_p))

然而,where 中的 Pytorch 方法抛出以下错误:

<块引用>

预期标量类型为双精度但发现为浮点

你能帮我解决这个问题吗?

1 个答案:

答案 0 :(得分:1)

我无法重现此错误,如果您能分享一个失败的具体示例(可能是您尝试用其填充张量的值),也许会更好:

import torch
x = torch.rand(4,6)
res = torch.where(x > 0.3,torch.tensor(0.), torch.where(x < 0.1, torch.tensor(-1.), x))

x 在哪里,它是 dtype float32:

tensor([[0.1391, 0.4491, 0.2363, 0.3215, 0.7740, 0.4879],
        [0.3051, 0.0870, 0.2869, 0.2575, 0.8825, 0.8201],
        [0.4419, 0.1138, 0.0825, 0.9489, 0.1553, 0.6505],
        [0.8376, 0.7639, 0.9291, 0.0865, 0.5984, 0.3953]])

res 是:

tensor([[ 0.1391,  0.0000,  0.2363,  0.0000,  0.0000,  0.0000],
        [ 0.0000, -1.0000,  0.2869,  0.2575,  0.0000,  0.0000],
        [ 0.0000,  0.1138, -1.0000,  0.0000,  0.1553,  0.0000],
        [ 0.0000,  0.0000,  0.0000, -1.0000,  0.0000,  0.0000]])

问题是因为您在 torch.where 中混合了数据类型,如果您在常量中明确使用与张量相同的数据类型,它可以正常工作。