酸洗时保留numpy视图

时间:2012-12-06 15:12:31

标签: python numpy view pickle

默认情况下,即使数组基础也被pickle,对numpy视图数组进行pickle也会丢失视图关系。我的情况是我有一些复杂的容器对象被腌制。在某些情况下,某些包含的数据是其他一些数据。保存每个视图的独立数组不仅会造成空间损失,而且重新加载的数据也会丢失视图关系。

一个简单的例子是(但在我的例子中,容器比字典更复杂):

import numpy as np
import cPickle

tmp = np.zeros(2)
d1 = dict(a=tmp,b=tmp[:])    # d1 to be saved: b is a view on a

pickled = cPickle.dumps(d1)
d2 = cPickle.loads(pickled)  # d2 reloaded copy of d1 container

print 'd1 before:', d1
d1['b'][:] = 1
print 'd1 after: ', d1

print 'd2 before:', d2
d2['b'][:] = 1
print 'd2 after: ', d2

会打印:

d1 before: {'a': array([ 0.,  0.]), 'b': array([ 0.,  0.])}
d1 after:  {'a': array([ 1.,  1.]), 'b': array([ 1.,  1.])}
d2 before: {'a': array([ 0.,  0.]), 'b': array([ 0.,  0.])}
d2 after:  {'a': array([ 0.,  0.]), 'b': array([ 1.,  1.])} # not a view anymore

我的问题:

(1)有没有办法保存它?  (2)(甚至更好)只有在基础被腌制时才有办法做到这一点

对于(1)我认为可以通过更改视图数组的__setstate____reduce_ex_等来实现某种方式。但是我现在还不满足于这些。对于(2)我不知道。

1 个答案:

答案 0 :(得分:7)

这不是在NumPy中完成的,因为挑选基础数组并不总是有意义的,并且pickle不会暴露检查另一个对象是否也被作为其一部分被腌制的能力API。

但是这种检查可以在NumPy数组的自定义容器中完成。例如:

import numpy as np
import pickle

def byte_offset(array, source):
    return array.__array_interface__['data'][0] - np.byte_bounds(source)[0]

class SharedPickleList(object):
    def __init__(self, arrays):
        self.arrays = list(arrays)

    def __getstate__(self):
        unique_ids = {id(array) for array in self.arrays}
        source_arrays = {}
        view_tuples = {}
        for array in self.arrays:
            if array.base is None or id(array.base) not in unique_ids:
                # only use views if the base is also being pickled
                source_arrays[id(array)] = array
            else:
                view_tuples[id(array)] = (array.shape,
                                          array.dtype,
                                          id(array.base),
                                          byte_offset(array, array.base),
                                          array.strides)
        order = [id(array) for array in self.arrays]
        return (source_arrays, view_tuples, order)

    def __setstate__(self, state):
        source_arrays, view_tuples, order = state
        view_arrays = {}
        for k, view_state in view_tuples.items():
            (shape, dtype, source_id, offset, strides) = view_state
            buffer = source_arrays[source_id].data
            array = np.ndarray(shape, dtype, buffer, offset, strides)
            view_arrays[k] = array
        self.arrays = [source_arrays[i]
                       if i in source_arrays
                       else view_arrays[i]
                       for i in order]

# unit tests
def check_roundtrip(arrays):
    unpickled_arrays = pickle.loads(pickle.dumps(
        SharedPickleList(arrays))).arrays
    assert all(a.shape == b.shape and (a == b).all()
               for a, b in zip(arrays, unpickled_arrays))

indexers = [0, None, slice(None), slice(2), slice(None, -1),
            slice(None, None, -1), slice(None, 6, 2)]

source0 = np.random.randint(100, size=10)
arrays0 = [np.asarray(source0[k1]) for k1 in indexers]
check_roundtrip([source0] + arrays0)

source1 = np.random.randint(100, size=(8, 10))
arrays1 = [np.asarray(source1[k1, k2]) for k1 in indexers for k2 in indexers]
check_roundtrip([source1] + arrays1)

这样可以节省大量空间:

source = np.random.rand(1000)
arrays = [source] + [source[n:] for n in range(99)]
print(len(pickle.dumps(arrays, protocol=-1)))
# 766372
print(len(pickle.dumps(SharedPickleList(arrays), protocol=-1)))
# 11833