是否有任何干净的方法来设置numpy以使用float32值而不是全局使用float64?
答案 0 :(得分:9)
不是我知道的。您需要在调用任何数组的构造函数时显式指定dtype,或者在将数组传递给GPU代码之前将数组转换为float32(使用ndarray.astype方法)(我认为这是问题所针对的? )。如果是你真正担心的GPU情况,我赞成后者 - 如果没有对numpy广播规则和非常精心设计的代码进行非常透彻的理解,尝试保持所有内容都会变得非常烦人。
另一种替代方法可能是创建自己的方法来重载标准的numpy构造函数(所以numpy.zeros,numpy.ones,numpy.empty)。这应该非常接近将所有内容保存在float32中。
答案 1 :(得分:1)
该问题出现在NumPy问题跟踪器上。 The answer是:
没有,抱歉。而且恐怕我们不太可能添加此类内容[。]
答案 2 :(得分:1)
对于每个功能,您都可以通过以下方式重载:
def array(*args, **kwargs):
kwargs.setdefault("dtype", np.float32)
return np.array(*args, **kwargs)
njsmith在github上发布