为什么 multiprocessing 会序列化我的函数和闭包?

时间:2019-10-17 23:46:55

标签: python multiprocessing python-multiprocessing

根据 https://docs.python.org/3/library/multiprocessing.html ,multiprocessing 在 *nix 上通过 fork(派生)来创建执行任务的工作进程。我们可以在 fork 之前给某个模块设置一个全局变量来验证这一点:如果工作函数导入该模块后能找到这个变量,就说明进程内存已被复制。事实确实如此:

import os

def f(x):
    """Return the pid stored on ``sys._mypid``; the argument ``x`` is ignored."""
    import sys as _sys

    # When called through the pool, this lookup happens in a worker process.
    stored_pid = _sys._mypid
    return stored_pid


def set_state():
    """Store the current process id on the ``sys`` module as ``sys._mypid``."""
    import sys as _sys

    current_pid = os.getpid()
    _sys._mypid = current_pid

def g():
    """Run ``f`` over ``range(1000)`` on a 4-worker pool and print each result."""
    from multiprocessing import Pool

    workers = Pool(4)
    try:
        results = workers.imap(f, range(1000))
        for value in results:
            print(value)
    finally:
        # Shut the pool down whether or not iteration succeeded.
        workers.close()
        workers.join()

if __name__=='__main__':
    # Record the parent's pid on sys before g() creates the pool, so the
    # value is already in place when the worker processes are created.
    set_state()
    g()

但是,如果它是这样工作的,那么 multiprocessing 为什么还要序列化工作函数 f 呢?

在此示例中:

import os

def set_state():
    """Attach this process's pid to the ``sys`` module under ``_mypid``."""
    import sys as sys_module

    sys_module._mypid = os.getpid()

def g():
    """Run a locally defined ``f`` on a 4-worker pool and print each result."""
    # f is deliberately nested inside g: handing this local function to
    # pool.imap is what this example exercises.
    def f(x):
        import sys
        return sys._mypid

    from multiprocessing import Pool
    pool = Pool(4)
    try:
        for z in pool.imap(f, range(1000)):
            print(z)
    finally:
        # Always shut the pool down, even when iteration raises.
        pool.close()
        pool.join()

if __name__=='__main__':
    # Store the parent's pid on sys first, then run the pool demo.
    set_state()
    g()

我们得到:

AttributeError: Can't pickle local object 'g.<locals>.f'

Stack Overflow 和互联网上有很多绕过此问题的方法。(Python 标准库的 pickle 是按限定名引用来序列化函数的,因此可以处理模块级函数,但无法处理局部函数和闭包数据。)

但是我们为什么会走到这一步?fork 出来的进程内存中已经有 f 的写时复制副本,为什么还需要对它进行序列化?

1 个答案:

答案 0 :(得分:-1)

Derp——必须采用这种方式,因为工作进程在创建 Pool 时就已经 fork 出来了,而函数的引用是之后才传给 imap 的,所以只能通过序列化把它发送给工作进程:

    pool = Pool(4)  <<< processes created here

    for z in pool.imap(f, range(1000)):   <<< reference to function

仅供参考:如果希望 fork 出的新进程能够直接访问该函数(从而避免序列化该函数),可以采用以下模式:

import collections
import multiprocessing as mp
import os
import pickle
import threading

_STATUS_DATA = 0    # Message carries an argument (input) or a result (output).
_STATUS_ERR = 1     # Message carries the exception raised while handling an item.
_STATUS_POISON = 2  # Shutdown request (input queue) or shutdown ack (output queue).


# Envelope exchanged over both queues. ``sequence_id`` is the index of the
# item in ``args`` (-1 when it is unknown or not applicable).
Message = collections.namedtuple(
    "Message",
    ["status",
     "payload",
     "sequence_id"
     ]
)


def parallel_map(
        target,
        args,
        num_processes,
        inq_maxsize=None,
        outq_maxsize=None,
        serialize=pickle.dumps,
        deserialize=pickle.loads,
        start_method="fork",
        preserve_order=True,
):
    """
    Lazily apply ``target`` to every item of ``args`` on a pool of workers.

    This is a generator: nothing runs (and no argument is validated) until
    it is first advanced. Workers are created inside this call, so forked
    workers reach ``target`` through the closure and the callable itself is
    never passed through ``serialize``; only arguments and results are.

    :param target: Callable invoked as ``target(item)`` for each item.
    :param args: Iterable of single parameter arguments for target. It is
      consumed by a background sender thread.
    :param num_processes: Number of workers; must be at least 1.
    :param inq_maxsize: Capacity of the input queue
      (default ``10 * num_processes``).
    :param outq_maxsize: Capacity of the output queue
      (default ``10 * num_processes``).
    :param serialize: Encodes a ``Message`` before it is queued.
    :param deserialize: Inverse of ``serialize``; must return a ``Message``.
    :param start_method: ``"thread"`` runs the workers as threads; any other
      value runs them as processes from the ``fork`` multiprocessing context.
    :param preserve_order: If true, results are yielded in the order of
      ``args``. Otherwise each result is yielded as soon as it arrives.
    :return: Generator yielding the results of ``target``.
    :raises ValueError: If ``num_processes`` is less than 1.
    :raises TypeError: If the codec does not round-trip a ``Message``.
    :raises Exception: The exception a worker reported for an item (or a
      generic ``Exception("Unexpected exception")`` when it could not be
      transported), or an error raised while reading or encoding ``args``.
    """
    import queue  # Local import: only queue.Empty is needed, for the drain.

    if num_processes < 1:
        # Without workers nobody would ever acknowledge shutdown and the
        # consumer loop below would block forever.
        raise ValueError("num_processes must be at least 1")
    if inq_maxsize is None:
        inq_maxsize = 10 * num_processes
    if outq_maxsize is None:
        outq_maxsize = 10 * num_processes

    poison = serialize(Message(_STATUS_POISON, None, -1))
    # Fail fast: a worker that cannot recognise the poison pill never exits.
    if not isinstance(deserialize(poison), Message):
        raise TypeError("deserialize(serialize(message)) must return a Message")

    inq = mp.Queue(maxsize=inq_maxsize)
    outq = mp.Queue(maxsize=outq_maxsize)

    def work():
        """Worker loop: decode a request, run ``target``, report the outcome."""
        while True:
            obj = inq.get()
            seq_id = -1  # Unknown until the request has been decoded.
            try:
                msg = deserialize(obj)
                seq_id = msg.sequence_id
                if msg.status == _STATUS_POISON:
                    # Acknowledge shutdown; this is the worker's last message.
                    outq.put(serialize(Message(_STATUS_POISON, None, seq_id)))
                    return
                call_args, call_kw = msg.payload
                result = target(*call_args, **call_kw)
                outq.put(serialize(Message(_STATUS_DATA, result, seq_id)))
            except Exception as e:
                try:
                    outq.put(serialize(Message(_STATUS_ERR, e, seq_id)))
                except Exception:
                    # The exception itself could not be serialized: report a
                    # payload-free error under the same sequence id so the
                    # consumer still learns which item failed.
                    try:
                        outq.put(serialize(Message(_STATUS_ERR, None, seq_id)))
                    except Exception:
                        # Best effort only: the codec cannot even encode a
                        # bare error marker, so nothing can be reported.
                        pass

    if start_method == "thread":
        make_worker = threading.Thread
    else:
        make_worker = mp.get_context('fork').Process

    processes = [
        make_worker(target=work, name="parallel_map.work")
        for _ in range(num_processes)
    ]

    stop = threading.Event()  # Set by the consumer to stop feeding new work.
    send_errors = []          # Error raised while reading/encoding ``args``.

    def send():
        """Feed ``args`` to the workers, then always send the poison pills."""
        try:
            for seq_id, arg in enumerate(args):
                if stop.is_set():
                    break
                obj = ((arg,), {})
                inq.put(serialize(Message(_STATUS_DATA, obj, seq_id)))
        except Exception as exc:
            # Hand the failure to the consumer instead of dying silently.
            send_errors.append(exc)
        finally:
            # One pill per worker, sent exactly once and only from here.
            for _ in range(num_processes):
                inq.put(poison)

    def unwrap(msg):
        """Return the payload of a data message or raise the reported error."""
        if msg.status == _STATUS_ERR:
            if isinstance(msg.payload, Exception):
                raise msg.payload
            raise Exception("Unexpected exception")
        return msg.payload

    sender = threading.Thread(
        target=send,
        name="parallel_map.sender",
        daemon=True)

    # Start the workers before the sender thread so that, in fork mode, the
    # fork happens before this function has created any thread of its own.
    for p in processes:
        p.start()
    sender.start()

    nquit = 0  # Number of shutdown acknowledgements received so far.
    try:
        buffer = {}   # Out-of-order results keyed by sequence id.
        nyielded = 0  # Next sequence id to yield when preserving order.
        while nquit < num_processes:
            msg = deserialize(outq.get())
            if msg.status == _STATUS_POISON:
                nquit += 1
                continue
            if preserve_order and msg.sequence_id >= 0:
                buffer[msg.sequence_id] = msg
                while nyielded in buffer:
                    ready = buffer.pop(nyielded)
                    nyielded += 1
                    yield unwrap(ready)
            else:
                # Unordered mode, or an error that could not be tied to an
                # item (sequence id -1): deliver it immediately.
                yield unwrap(msg)
        if send_errors:
            raise send_errors[0]
    finally:
        # Reached on normal completion, on an error, and when the caller
        # closes the generator early. Stop feeding work, then keep draining
        # the output queue until every worker has acknowledged shutdown;
        # otherwise workers blocked on a full output queue (and the sender
        # blocked on a full input queue) would make the joins below hang.
        stop.set()
        all_acked = True
        while nquit < num_processes:
            try:
                raw = outq.get(timeout=0.1)
            except queue.Empty:
                if any(p.is_alive() for p in processes):
                    continue
                # Every worker is gone without acknowledging (e.g. killed).
                all_acked = False
                break
            try:
                msg = deserialize(raw)
            except Exception:
                continue  # Undecodable leftover result; discard it.
            if isinstance(msg, Message) and msg.status == _STATUS_POISON:
                nquit += 1
        if all_acked:
            # Only safe to wait when the workers consumed everything; else
            # the (daemon) sender may be blocked on the input queue forever.
            sender.join()
        for p in processes:
            p.join()

用法:

        def f(x):
            """Sleep briefly, then return ``x``; raise for the input ``-1``."""
            import time  # Local import keeps the usage snippet self-contained.

            time.sleep(0.01)
            if x ==-1:
                raise Exception("Boo")
            return x

        for result in parallel_map(target=f,  # f itself is not serialized
                                   args=range(100),
                                   num_processes=8,
                                   start_method="fork"):
            pass

……但有一个警告:每当你在一个带有线程的程序中 fork 时,就会有一只小狗死去(也就是说,在多线程程序中 fork 是有风险的)。