为了使用 Numpy
在随机向量中用某个数字替换正值,用另一个数字替换负值,可以执行以下操作:
npy_p = np.random.randn(4,6)
quant = np.where(npy_p>0, c_plus , np.where(npy_p<0, c_minus , npy_p))
然而,where
中的 Pytorch
方法抛出以下错误:
预期标量类型为双精度但发现为浮点
你能帮我解决这个问题吗?
答案 0 :(得分:1)
我无法重现此错误,如果您能分享一个失败的具体示例(可能是您尝试用其填充张量的值),也许会更好:
import torch
x = torch.rand(4,6)
res = torch.where(x > 0.3,torch.tensor(0.), torch.where(x < 0.1, torch.tensor(-1.), x))
x
在哪里,它是 dtype
float32:
tensor([[0.1391, 0.4491, 0.2363, 0.3215, 0.7740, 0.4879],
[0.3051, 0.0870, 0.2869, 0.2575, 0.8825, 0.8201],
[0.4419, 0.1138, 0.0825, 0.9489, 0.1553, 0.6505],
[0.8376, 0.7639, 0.9291, 0.0865, 0.5984, 0.3953]])
而 res
是:
tensor([[ 0.1391, 0.0000, 0.2363, 0.0000, 0.0000, 0.0000],
[ 0.0000, -1.0000, 0.2869, 0.2575, 0.0000, 0.0000],
[ 0.0000, 0.1138, -1.0000, 0.0000, 0.1553, 0.0000],
[ 0.0000, 0.0000, 0.0000, -1.0000, 0.0000, 0.0000]])
问题是因为您在 torch.where
中混合了数据类型,如果您在常量中明确使用与张量相同的数据类型,它可以正常工作。