Numpy标量的二进制运算自动向上转换为float64

时间:2014-09-18 22:16:52

标签: python numpy numba single-precision

我想在np.float32和内置Python int和float之间进行二进制操作(如add和multiply),并获取np.float32作为返回类型。但是,它会自动升级到np.float64。

示例代码:

>>> a = np.float32(5)
>>> a.dtype
dtype('float32')
>>> b = a + 2
>>> b.dtype
dtype('float64')

如果我使用np.float128执行此操作,b也会变为np.float128。这很好,因为它可以保持精确度。但是,在我的示例中,不需要向上转换为np.float64来保持精度,但它仍然会发生。如果我将2.0(一个Python浮点数(64位))添加到一个而不是2,那么转换是有意义的。但即使在这里,我也不想要它。

所以我的问题是:在将二元运算符应用于np.float32和内置Python int / float时,如何更改完成的转换?或者,在所有计算中使单精度成为标准,而不是加倍,也算作解决方案,因为我不需要双精度。其他人已经要求这样做了,似乎找不到任何解决方案。

我知道numpy数组和dtypes。在这里,我得到了想要的行为,因为数组总是保留它的dtype。然而,当我对一个数组的单个元素进行操作时,我得到了不需要的行为。 我对一个解决方案有一个模糊的想法,涉及子类化np.ndarray(或np.float32)并更改__array_priority__的值。到目前为止,我还没能让它发挥作用。

我为什么在意?我正在尝试使用Numba编写一个n体代码。这就是为什么我不能简单地对阵列进行整体操作的原因。将所有np.float64更改为np.float32可使速度提高约2倍,这很重要。 np.float64-casting行为可以完全破坏这种速度,因为我的np.float32数组上的所有操作都以64精度完成,然后下降到32精度。

1 个答案:

答案 0 :(得分:1)

我不确定NumPy的行为,或者你是如何尝试使用Numba的,但明确Numba类型可能有所帮助。例如,如果您执行以下操作:

@jit
def foo(a):
    return a[0] + 2;

a = np.array([3.3], dtype='f4')
foo(a)

在添加操作之前,[0]中的float32值被提升为float64(如果你不介意深入llvm IR,你可以通过使用numba命令运行代码并使用 - 来自己看到这个 - -dump-llvm或--dump-optimized flag:numba --dump-optimized numba_test.py)。但是,通过指定函数签名,包括返回类型为float32:

@jit('f4(f4[:]'))
def foo(a):
    return a[0] + 2;

[0]中的值不会提升为float64,尽管结果会转换为float64,因此当函数返回到Python land时,它可以转换为Python float对象。

如果你可以事先分配一个数组来保存结果,你可以这样做:

@jit
def foo():
    a = np.arange(1000000, dtype='f4')
    result = np.zeros(1000000, dtype='f4')
    for i in range(a.size):
        result[0] = a[0] + 2

即使你自己进行循环,编译代码的性能应该与NumPy ufunc相当,并且不应该对float64进行强制转换(同样,这可以通过查看Numba生成的llvm IR来验证) )。