将任意对象传递给子进程

时间:2018-03-30 20:40:27

标签: python python-3.x multiprocessing

假设我想使用以下类在进程之间共享内存:

class SharedNumpyArray:
    def __init__(self, shape, dtype=np.float64):
        self.shape = shape
        self.dtype = np.dtype(dtype)
        self.buf_size = self.dtype.itemsize * np.prod(shape)
        self.__buf = RawArray('b', self.buf_size)
        self.init_buf()

    def init_buf(self):
        self.buf = np.frombuffer(self.__buf, dtype=self.dtype).reshape(self.shape)

将对象传递给子进程时会发生什么?

将此类的实例传递给子进程时,需要在该进程中再次调用init_buf。有没有办法实现自动化?

此外,如果在将实例传递给子进程时已经定义了__buf会发生什么?

我认为,如果我在将类的实例传递给子进程之前确定self.__buf = None,然后我确保它们中的每一个都调用init_buf,那么一切都应该正常工作,但是有更好的方法吗?

我怀疑不同操作系统之间的实现细节差异很大,因为某些fork进程(例如Linux),而其他操作系统从头开始创建新的(例如Windows)。

1 个答案:

答案 0 :(得分:0)

当父进程调用与子进程关联的start实例的Process方法时,对象实际传递给子进程。

使用multiprocessing自己的pickle变体通过序列化传递对象。这意味着我们需要做的就是在我们的类中添加__reduce__方法。

示例(Python 3):

from multiprocessing import Process
from multiprocessing.sharedctypes import RawArray
import numpy as np


def _rebuild(shape, dtype, _buf):     # must be defined at the module level!
    return SharedNumpyArray(shape, dtype, buffer=_buf)


class SharedNumpyArray(np.ndarray):
    def __new__(subtype, shape, dtype=float, buffer=None, offset=0,
                strides=None, order=None):
        dtype = np.dtype(dtype)
        buf_size = dtype.itemsize * np.prod(shape)
        __buf = buffer or RawArray('b', int(buf_size))
        obj = super().__new__(subtype, shape, dtype, __buf, offset, strides,
                              order)
        obj.__buf = __buf
        return obj

    def __reduce__(self):
        return _rebuild, (self.shape, self.dtype, self.__buf)


class TestProc(Process):
    def __init__(self, shared_buf):
        super().__init__()
        self.shared_buf = shared_buf

    def run(self):
        self.shared_buf[0] = 12
        self.shared_buf[1:5] = 13
        self.shared_buf[1:5] *= 3


def test_proc():
    buf = SharedNumpyArray(1000, dtype=np.byte)
    p = TestProc(buf)
    p.start()
    p.join()
    assert buf[0] == 12
    assert np.all(buf[1:5] == 13 * 3)
    assert len(buf) == 1000


if __name__ == '__main__':
    test_proc()