根据 https://docs.python.org/3/library/multiprocessing.html,multiprocessing 在 *nix 上通过 fork 创建工作进程来执行任务(注:这是 Python 3.14 之前 Linux 上的默认启动方式;macOS 自 Python 3.8 起默认使用 spawn)。我们可以这样验证:在 fork 之前给某个模块设置一个全局变量,如果工作函数导入该模块后能找到这个变量,就说明进程内存确实被复制了。示例如下:
import os
def f(x):
    """Return ``sys._mypid`` as seen by whichever process executes this call.

    ``x`` is ignored; it only exists so that ``f`` can be mapped over a range.
    In a worker forked after ``set_state`` ran, the value is the parent's PID,
    i.e. it arrived with the copied process memory.
    """
    import sys as _sys
    return getattr(_sys, "_mypid")  # <<< value is read back from the inherited state
def set_state():
    """Record this process's PID on the ``sys`` module as ``sys._mypid``.

    Meant to run in the parent before the worker pool exists, so the attribute
    is already present in the memory image that workers are forked from.
    """
    import sys as _sys
    setattr(_sys, "_mypid", os.getpid())
def g():
    """Map ``f`` over ``range(1000)`` with a 4-process pool and print each result.

    The pool comes from the explicit ``"fork"`` context because the
    demonstration only works when workers are forked from this process.
    ``Pool()`` uses the platform default start method instead: ``spawn`` on
    macOS and ``forkserver`` on Linux from Python 3.14. Under those methods the
    children would not inherit the ``sys._mypid`` marker set by ``set_state``.

    Raises:
        ValueError: if the platform has no ``fork`` start method (e.g. Windows).
    """
    import multiprocessing

    pool = multiprocessing.get_context("fork").Pool(4)
    try:
        # f itself is still pickled (by reference) for every task; only the
        # sys._mypid marker travels through the copied memory.
        for z in pool.imap(f, range(1000)):
            print(z)
    finally:
        # close() stops new submissions; join() waits for the workers to exit.
        pool.close()
        pool.join()
if __name__ == "__main__":
    # Order matters: the marker must exist in the parent before g() creates
    # the pool and the workers are started.
    set_state()
    g()
但是,既然进程是这样通过 fork 创建的,multiprocessing 为什么还要序列化工作函数 f 呢?
在此示例中:
import os
def set_state():
    """Store the PID of the calling process as ``sys._mypid``."""
    import sys
    current_pid = os.getpid()
    sys._mypid = current_pid
def g():
    """Try to map a *nested* ``f`` over ``range(1000)`` in a forked pool.

    This is the failing counterexample. The explicit ``"fork"`` context makes
    the workers forked on every POSIX platform and Python version, matching the
    question's premise. Even so, ``Pool.imap`` pickles the callable for each
    task, and a function defined inside ``g`` cannot be pickled. Iterating the
    results therefore raises. The question shows ``AttributeError: Can't pickle
    local object 'g.<locals>.f'``; the exact exception type may differ between
    Python versions.
    """
    def f(x):
        import sys
        return sys._mypid

    import multiprocessing

    pool = multiprocessing.get_context("fork").Pool(4)
    try:
        for z in pool.imap(f, range(1000)):
            print(z)
    finally:
        pool.close()
        pool.join()
if __name__ == "__main__":
    # Same setup as before, but g() now hands a nested function to the pool,
    # which fails while the tasks are being pickled.
    set_state()
    g()
我们得到:
AttributeError: Can't pickle local object 'g.<locals>.f'
Stack Overflow 和网上有大量绕过这个问题的办法。(Python 标准的 pickle 可以序列化模块级函数——它只按限定名保存对函数的引用——但无法序列化在函数内部定义的局部函数或闭包。)
但为什么会走到这一步呢?fork 出来的进程内存中已经有 f 的写时复制副本,
为什么还需要序列化它?
答案 0(得分:-1)
哎呀,当然必须这样:工作进程在创建 Pool 时就已经 fork 出来了,而要执行哪个函数要等到之后调用 imap 时才确定,所以 f 只能序列化后发送给已在运行的工作进程:
pool = Pool(4) <<< processes created here
for z in pool.imap(f, range(1000)): <<< reference to function
仅供参考:如果希望 fork 出的新进程直接访问该函数(从而避免序列化它),可以采用下面的模式——在调用时才 fork 工作进程,此时目标函数已经确定并存在于内存中:
import collections
import multiprocessing as mp
import os
import pickle
import threading
_STATUS_DATA = 0
_STATUS_ERR = 1
_STATUS_POISON = 2

# Envelope for every item that travels over the queues, in both directions.
# ``sequence_id`` is the input's position in ``args`` (-1 when unknown).
Message = collections.namedtuple(
    "Message",
    ["status",
     "payload",
     "sequence_id",
     ],
)


def parallel_map(
        target,
        args,
        num_processes,
        inq_maxsize=None,
        outq_maxsize=None,
        serialize=pickle.dumps,
        deserialize=pickle.loads,
        start_method="fork",
        preserve_order=True,
):
    """Lazily yield ``target(arg)`` for every ``arg`` in ``args``, in parallel.

    Unlike ``multiprocessing.Pool``, ``target`` itself is never serialized.
    The workers are created inside this call, after ``target`` exists, using
    the ``fork`` start method, so they inherit it through the copied memory.
    Only arguments, results and errors travel over the queues.

    :param target: Callable taking a single positional argument. It may be a
        closure or lambda, because it is not pickled.
    :param args: Iterable of single arguments for ``target``. It is consumed
        by a background thread, so it may be lazy.
    :param num_processes: Number of workers; must be at least 1.
    :param inq_maxsize: Bound of the input queue (default ``10 * num_processes``).
    :param outq_maxsize: Bound of the output queue (default ``10 * num_processes``).
    :param serialize: Encoder applied to every ``Message`` put on a queue.
    :param deserialize: Inverse of ``serialize``.
    :param start_method: ``"thread"`` runs the workers as threads; any other
        value forks processes (only a forked child can share ``target``).
    :param preserve_order: If true, results are yielded in the order of
        ``args``; otherwise each result is yielded as soon as it arrives.
    :return: A generator of results.
    :raises ValueError: if ``num_processes`` is less than 1.
    :raises Exception: whatever ``target`` raised (re-raised in the caller),
        whatever iterating ``args`` raised, or ``Exception("Unexpected
        exception")`` when a worker's error could not be serialized.

    Note: iteration waits for every worker to acknowledge its poison pill, so
    a worker process that dies abnormally makes it block.
    """
    if num_processes < 1:
        # No worker would ever acknowledge a poison pill, so the consumer
        # loop below would block forever.
        raise ValueError("num_processes must be at least 1")
    if inq_maxsize is None:
        inq_maxsize = 10 * num_processes
    if outq_maxsize is None:
        outq_maxsize = 10 * num_processes
    inq = mp.Queue(maxsize=inq_maxsize)
    outq = mp.Queue(maxsize=outq_maxsize)
    poison = serialize(Message(_STATUS_POISON, None, -1))
    deserialize(poison)  # Fail fast if the codec cannot round-trip a Message.

    def report_error(exc, seq_id):
        """Best effort: send ``exc`` for input ``seq_id`` to the consumer."""
        try:
            outq.put(serialize(Message(_STATUS_ERR, exc, seq_id)))
        except Exception:
            # The exception itself could not be serialized: send a bare error
            # marker so the consumer still learns that this input failed.
            try:
                outq.put(serialize(Message(_STATUS_ERR, None, seq_id)))
            except Exception:
                pass  # Nothing else can be reported from inside a worker.

    def work():
        """Worker loop: run ``target`` on each input until a poison pill."""
        while True:
            raw = inq.get()
            seq_id = -1  # Unknown until the message has been decoded.
            try:
                msg = deserialize(raw)
                assert isinstance(msg, Message)
                seq_id = msg.sequence_id
                if msg.status == _STATUS_POISON:
                    # Acknowledge the pill so the consumer can count exits.
                    outq.put(serialize(Message(_STATUS_POISON, None, seq_id)))
                    return
                call_args, call_kw = msg.payload
                result = target(*call_args, **call_kw)
                outq.put(serialize(Message(_STATUS_DATA, result, seq_id)))
            except Exception as exc:
                report_error(exc, seq_id)

    if start_method == "thread":
        worker_factory = threading.Thread
    else:
        # Every non-"thread" value forks: ``work`` is a closure over
        # ``target`` and the queues, which only a forked child can inherit.
        worker_factory = mp.get_context("fork").Process
    workers = [
        worker_factory(target=work, name="parallel_map.work")
        for _ in range(num_processes)
    ]
    # Start (fork) the workers before this call starts its own sender thread,
    # so that thread cannot be mid-operation while the children are forked.
    for worker in workers:
        worker.start()

    stop = threading.Event()  # Set when the consumer exits early.
    send_errors = []  # Exception raised while producing inputs, if any.

    def send():
        """Feed ``args`` into ``inq``, then one poison pill per worker."""
        try:
            for seq_id, arg in enumerate(args):
                if stop.is_set():
                    break  # The consumer has gone away; stop producing.
                inq.put(serialize(Message(_STATUS_DATA, ((arg,), {}), seq_id)))
        except Exception as exc:
            # Surface the failure (e.g. ``args`` raising) to the consumer
            # instead of dying without sending pills, which would leave the
            # consumer waiting forever.
            send_errors.append(exc)
        finally:
            # Always release the workers, however the loop ended.
            for _ in range(num_processes):
                inq.put(poison)

    sender = threading.Thread(
        target=send, name="parallel_map.sender", daemon=True)
    sender.start()

    def unwrap(msg):
        """Return the payload of a result message, or raise its error."""
        if msg.status == _STATUS_ERR:
            if isinstance(msg.payload, Exception):
                raise msg.payload
            raise Exception("Unexpected exception")
        assert msg.status == _STATUS_DATA
        return msg.payload

    acks = 0  # Poison-pill acknowledgements, i.e. workers that have exited.
    try:
        buffered = {}  # sequence_id -> Message, waiting for its turn.
        next_seq = 0
        while acks < num_processes:
            msg = deserialize(outq.get())
            assert isinstance(msg, Message)
            if msg.status == _STATUS_POISON:
                acks += 1
            elif not preserve_order or msg.sequence_id < 0:
                # Unordered mode, or an error that cannot be tied to an input
                # position: deliver it immediately.
                yield unwrap(msg)
            else:
                buffered[msg.sequence_id] = msg
                while next_seq in buffered:
                    ready = buffered.pop(next_seq)
                    next_seq += 1
                    yield unwrap(ready)
        if send_errors:
            raise send_errors[0]
    finally:
        if acks < num_processes:
            # Early exit: an error was raised, or the caller stopped
            # iterating. Tell the sender to stop, then keep draining the
            # output queue until every worker has acknowledged its pill.
            # Otherwise workers block on the full ``outq``, the sender blocks
            # on the full ``inq``, and the joins below never return.
            stop.set()
            while acks < num_processes:
                try:
                    leftover = deserialize(outq.get())
                except Exception:
                    continue  # Undecodable result; it is discarded anyway.
                if leftover.status == _STATUS_POISON:
                    acks += 1
        sender.join()
        for worker in workers:
            worker.join()
def f(x):
    """Example target: sleep 10 ms, then return ``x`` unchanged.

    :raises Exception: ``Exception("Boo")`` when ``x == -1``, to exercise
        error propagation through ``parallel_map``.
    """
    # ``time`` is not imported at the top of this snippet; import it here so
    # the call below does not fail with NameError.
    import time

    time.sleep(0.01)
    if x == -1:
        raise Exception("Boo")
    return x
用法:
def f(x):
    """Example target for the usage snippet: sleep 10 ms, then return ``x``.

    :raises Exception: ``Exception("Boo")`` when ``x == -1``.
    """
    # ``time`` is not imported anywhere in this snippet; import it locally so
    # the example actually runs.
    import time

    time.sleep(0.01)
    if x == -1:
        raise Exception("Boo")
    return x
# ``f`` is handed to parallel_map directly. The workers are forked inside the
# call, after ``f`` already exists, so the function itself is never pickled.
for result in parallel_map(target=f,  # <<< not serialized
                           args=range(100),
                           num_processes=8,
                           start_method="fork"):
    pass
……需要提醒一点:在多线程程序中 fork 很危险(俗话说:每 fork 一次,程序里每个线程都会害死一只小狗)。子进程只复制调用 fork 的那个线程,其他线程当时持有的锁在子进程中可能永远不会被释放。