序列化Numpy数组的意外行为

时间:2016-02-02 10:07:22

标签: python numpy

代码

假设我有:

import numpy
import pickle


class Test():
    def __init__(self):
        self.base = numpy.zeros(6)
        self.view = self.base[-3:]

    def __len__(self):
        return len(self.view)

    def update(self):
        self.view[0] += 1

    def add(self):
        self.view = self.base[-len(self.view) - 1:]
        self.view[0] = 1

    def __repr__(self):
        return str(self.view)


def serialize_data():
    data = Test()
    return pickle.dumps(data)

请注意,类Test只是一个包含view NumPy数组base的类。这个view只是基础中最后N个元素的一部分(初始化时为N == 3)。

Test有一个方法update(),可以将1添加到视图位置0的值,还有一个方法add()可以修改视图大小(N = N + 1)并将位置0的值设置为1

函数serialize_data只创建一个Test()实例,然后使用pickle返回序列化对象。

行为

如果我创建一个局部变量并update两次,add一次,那么一切都按预期工作:

# Local variable
test = Test()
print(test)    # [ 0.  0.  0.]

test.update()
test.update()
print(test)    # [ 2.  0.  0.]

test.add()
print(test)    # [ 1.  2.  0.  0.]

现在,如果我从序列化数据中创建局部变量,那么在执行add后,值2(在调用update两次后设置)似乎丢失了:

# Serialized variable
data = pickle.loads(serialize_data())
print(data)    # [ 0.  0.  0.]

data.update()
data.update()
print(data)    # [ 2.  0.  0.]

data.add()
print(data)    # [ 1.  0.  0.  0.]  <----  This should be [ 1. 2. 0. 0. ] !!!

问题

为什么会发生这种情况,我怎么能避免这种行为?

1 个答案:

答案 0 :(得分:2)

问题在于,在酸洗/去除斑点后,视图不再是基础视图,而是具有“&#39;拥有自己的数据副本。遗憾的是,See here对于如何防止这种情况没有答案。

通过为在取消排版后重新定义视图的类定义__getstate__ and __setstate__方法,可以克服特定问题。

除了视图外,还需要跟踪视图所在的基础部分。我选择使用切片对象,但还有其他方法。没有必要腌制视图本身,因为它将在切片时从切片重建。

class Test():
    def __init__(self):
        self.base = numpy.zeros(6)
        self.slice = slice(-3, self.base.size)
        self.view = self.base[self.slice]

    def __len__(self):
        return len(self.view)

    def update(self):
        self.view[0] += 1

    def add(self):
        self.slice = slice(-len(self.view) - 1, self.base.size)
        self.view = self.base[self.slice]        
        self.view[0] = 1

    def __getstate__(self):
        return {'base': self.base, 'slice': self.slice}

    def __setstate__(self, state):
        self.base = state['base']
        self.slice = state['slice']
        self.view = self.base[self.slice]

    def __repr__(self):
        return str(self.view)