如何在初始化numpy对象数组时避免额外的float对象副本

时间:2018-08-08 14:14:12

标签: python numpy numpy-ndarray

我天真地假设通过省略号[...]来分配值,例如

a = np.empty(N, dtype=np.object)
a[...] = 0.0

基本上是以下朴素循环的更快版本:

def slow_assign_1d(a, value):
    for i in range(len(a)):
        a[i] = value

但是事实并非如此。这是不同行为的示例:

>>> a=np.empty(2, dtype=np.object)
>>> a[...] = 0.0
>>> a[0] is a[1]
False

对象0.0似乎已被克隆。但是,当我使用朴素的慢版本时:

>>> a=np.empty(2, dtype=np.object)
>>> slow_assign(a, 0.0)
>>> a[0] is a[1]
True

所有元素都是“相同”的。

非常有趣的是,例如,使用自定义类,可以观察到期望的省略号行为:

>>> class A:
       pass
>>> a[...]=A()
>>> a[0] is a[1]
True 

为什么要对float进行这种“特殊”处理,并且有没有一种方法可以使用float值进行快速初始化而又不产生副本?

注意:np.full(...)a[:]的行为与a[...]相同:克隆了对象0.0 /创建了其副本。


编辑:正如@Till Hoffmann所指出的那样,字符串和整数的期望行为仅对于小整数(-5 ... 255)和短字符串(一个字符)而言是这样,因为他们来自一个游泳池,那里从来没有一个这样的物体。

>>> a[...] = 1         # or 'a'
>>> a[0] is a[1]
True
>>> a[...] = 1000      # or 'aa'
>>> a[0] is a[1]
False

似乎“期望的行为”仅适用于numpy类型,无法将其转换为某些类型,例如:

>>> class A(float): # can be downcasted to a float
>>>     pass
>>> a[...]=A()
>>> a[0] is a[1]
False

甚至更多,a[0]不再是A类型,而是float类型。

2 个答案:

答案 0 :(得分:3)

这实际上是整数而不是浮点数的问题。尤其是, “小”整数在python中缓存,因此它们全部都引用回相同的内存,因此具有相同的pos_tag(),因此与id运算符相比,它们是相同的。对于浮点数则不是这样。有关更深入的讨论,请参见"is" operator behaves unexpectedly with integers。有关“小”的正式定义,请参见https://docs.python.org/3/c-api/long.html#c.PyLong_FromLong


关于isA继承的特定示例,numpy documentation指出

  

请注意,如果将较高类型分配给较低类型[...]

,分配可能会导致更改

有人可能会争辩说,在上面提供的示例情况下,没有发生将较高类型分配给较低类型的情况,因为float应该是最通用的类​​型。但是,检查数组元素的类型后,很明显,在使用np.object分配进行分配时,类型会向下转换为float

[...]

顺便说一句:您可能无法通过存储对感兴趣对象的引用来节省大量内存,除非单个对象非常大。例如。存储单个精度浮点数比存储指向它的指针(在64位系统上)便宜。如果您的对象确实很大,则它们(可能)无法向下转换为原始类型,因此问题一开始就不太可能出现。

答案 1 :(得分:0)

此行为是一个小错误:https://github.com/numpy/numpy/issues/11701

因此,可能需要使用一种变通办法,直到修复该错误为止。我最终使用了用cython实现/编译的天真的慢版本,例如,这里是一维和np.full的例子:

%%cython
cimport numpy as np
import numpy as np
def cy_full(Py_ssize_t n, object obj):
    cdef np.ndarray[dtype=object] res = np.empty(n, dtype=object)
    cdef Py_ssize_t i
    for i in range(n):
        res[i]=obj
    return res

a=cy_full(5, np.nan)

a[0] is a[4]  # True as expected!

np.full相比,也没有性能劣势:

%timeit cy_full(1000, np.nan)
# 8.22 µs ± 39.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

%timeit np.full(1000, np.nan, dtype=np.object)
# 22.3 µs ± 129 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)