默认情况下,即使数组基础也被pickle,对numpy视图数组进行pickle也会丢失视图关系。我的情况是我有一些复杂的容器对象被腌制。在某些情况下,某些包含的数据是其他一些数据。保存每个视图的独立数组不仅会造成空间损失,而且重新加载的数据也会丢失视图关系。
一个简单的例子是(但在我的例子中,容器比字典更复杂):
import numpy as np
import cPickle
tmp = np.zeros(2)
d1 = dict(a=tmp,b=tmp[:]) # d1 to be saved: b is a view on a
pickled = cPickle.dumps(d1)
d2 = cPickle.loads(pickled) # d2 reloaded copy of d1 container
print 'd1 before:', d1
d1['b'][:] = 1
print 'd1 after: ', d1
print 'd2 before:', d2
d2['b'][:] = 1
print 'd2 after: ', d2
会打印:
d1 before: {'a': array([ 0., 0.]), 'b': array([ 0., 0.])}
d1 after: {'a': array([ 1., 1.]), 'b': array([ 1., 1.])}
d2 before: {'a': array([ 0., 0.]), 'b': array([ 0., 0.])}
d2 after: {'a': array([ 0., 0.]), 'b': array([ 1., 1.])} # not a view anymore
我的问题:
(1)有没有办法保存它? (2)(甚至更好)只有在基础被腌制时才有办法做到这一点
对于(1)我认为可以通过更改视图数组的__setstate__
,__reduce_ex_
等来实现某种方式。但是我现在还不满足于这些。对于(2)我不知道。
答案 0 :(得分:7)
这不是在NumPy中完成的,因为挑选基础数组并不总是有意义的,并且pickle不会暴露检查另一个对象是否也被作为其一部分被腌制的能力API。
但是这种检查可以在NumPy数组的自定义容器中完成。例如:
import numpy as np
import pickle
def byte_offset(array, source):
return array.__array_interface__['data'][0] - np.byte_bounds(source)[0]
class SharedPickleList(object):
def __init__(self, arrays):
self.arrays = list(arrays)
def __getstate__(self):
unique_ids = {id(array) for array in self.arrays}
source_arrays = {}
view_tuples = {}
for array in self.arrays:
if array.base is None or id(array.base) not in unique_ids:
# only use views if the base is also being pickled
source_arrays[id(array)] = array
else:
view_tuples[id(array)] = (array.shape,
array.dtype,
id(array.base),
byte_offset(array, array.base),
array.strides)
order = [id(array) for array in self.arrays]
return (source_arrays, view_tuples, order)
def __setstate__(self, state):
source_arrays, view_tuples, order = state
view_arrays = {}
for k, view_state in view_tuples.items():
(shape, dtype, source_id, offset, strides) = view_state
buffer = source_arrays[source_id].data
array = np.ndarray(shape, dtype, buffer, offset, strides)
view_arrays[k] = array
self.arrays = [source_arrays[i]
if i in source_arrays
else view_arrays[i]
for i in order]
# unit tests
def check_roundtrip(arrays):
unpickled_arrays = pickle.loads(pickle.dumps(
SharedPickleList(arrays))).arrays
assert all(a.shape == b.shape and (a == b).all()
for a, b in zip(arrays, unpickled_arrays))
indexers = [0, None, slice(None), slice(2), slice(None, -1),
slice(None, None, -1), slice(None, 6, 2)]
source0 = np.random.randint(100, size=10)
arrays0 = [np.asarray(source0[k1]) for k1 in indexers]
check_roundtrip([source0] + arrays0)
source1 = np.random.randint(100, size=(8, 10))
arrays1 = [np.asarray(source1[k1, k2]) for k1 in indexers for k2 in indexers]
check_roundtrip([source1] + arrays1)
这样可以节省大量空间:
source = np.random.rand(1000)
arrays = [source] + [source[n:] for n in range(99)]
print(len(pickle.dumps(arrays, protocol=-1)))
# 766372
print(len(pickle.dumps(SharedPickleList(arrays), protocol=-1)))
# 11833