I need to compute an expensive calculation function many times and would like to use all processor cores. It would be relatively simple if I had all the function argument sets at once: I could then use multiprocessing pool.map. But I do not have them all at once, and I also want to avoid starting a separate process for each function evaluation. So I would like to start a pool of workers (no problem), send jobs to them from clients (using a queue, no problem), and then return the results back to the clients (but how?).
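For contrast, if all parameter sets were available up front, something like this minimal sketch would suffice (expensive_f is a placeholder name, not part of the real code):

import multiprocessing

def expensive_f(args):
    # stand-in for the real expensive computation
    x, y = args
    return x * y

if __name__ == '__main__':
    params = [(x, y) for x in range(4) for y in range(4)]
    with multiprocessing.Pool() as pool:
        results = pool.map(expensive_f, params)
    print(results)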
More specifically, I need to compute my function on a multi-dimensional grid of parameters. In my case the computation for every single grid point can be avoided by exploiting two particular properties of my function: it is monotonous non-decreasing in all of its arguments, and it takes only a few discrete values. So if the function values at two distinct grid points are equal, it has the same value at all points in between. We then divide the grid into parts and use recursion. Using recursion, however, means that I cannot easily use pool.map.
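To illustrate the corner test without any of the multiprocessing machinery, here is a minimal serial sketch in 1D (fill_1d is a hypothetical helper, for illustration only):

import numpy as np

def fill_1d(f, x, out, i0, i1):
    # fill out[i0..i1] with f(x[i]); out holds -1 for "not yet computed"
    if out[i0] < 0:
        out[i0] = f(x[i0])
    if i0 == i1:
        return
    if out[i1] < 0:
        out[i1] = f(x[i1])
    # monotone f with equal end values => constant on the whole interval
    if out[i0] == out[i1]:
        out[i0:i1 + 1] = out[i0]
        return
    # otherwise split the interval in half and recurse
    mid = (i0 + i1) // 2
    fill_1d(f, x, out, i0, mid)
    fill_1d(f, x, out, mid + 1, i1)

x = np.linspace(0, 1, 64)
out = np.full(64, -1, dtype=np.int8)
fill_1d(lambda v: round(2 * v), x, out, 0, 63)
assert np.all(out == np.round(2 * x).astype(np.int8))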
I actually found a solution that works, but I guess there may be a more direct and more reliable way.
I set up a separate dispatcher thread. It receives all the results from the workers through a result queue and then lets the clients pick up their results. Below is the code for a simplified (2D) case.
EDIT: in reply to Shihab Shahriar's comment. Most of the code is needed to make the whole thing work, but is not directly related to the question. Specifically:

contpl() and sample_f() are auxiliary functions. In the real problem, sample_f() would be replaced by a complicated and expensive simulation.
recfill() is the recursive function. It needs the results of the expensive computations and receives them by calling getz(). It launches instances of itself recursively, in separate threads, as sketched below. The other details are not important here.
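One detail worth noting: a fresh ThreadPoolExecutor is created at every recursion level instead of sharing one fixed pool, so a parent task never occupies a worker slot that its children would need, which avoids deadlock. A minimal sketch of that pattern (count_down is only an illustration, not part of the code below):

from concurrent.futures import ThreadPoolExecutor

def count_down(n):
    if n == 0:
        return
    with ThreadPoolExecutor(max_workers=2) as p:
        # each level gets its own small pool; the parent only blocks
        # on a pool it created itself, never on a shared, exhausted one
        list(p.map(count_down, [n - 1, n - 1]))

count_down(3)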
getz() is the essential part of my solution: it is the proxy between the recursive function above and the pool of workers. It sends the job parameters, tagged by an id, to the worker pool via the queue task_q, then waits for an event from dispatcher(), checks whether the results of the computation have arrived, and returns them to its caller. Since the parent recfill() instances run in several threads, so do the getz() instances.
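The id tag is simply the flattened grid index computed with np.ravel_multi_index, which is what lets the dispatcher route each result back to the right grid point. For example:

import numpy as np

dims = (128, 128)
job_id = np.ravel_multi_index((3, 5), dims)      # row-major: 3*128 + 5 == 389
assert np.unravel_index(job_id, dims) == (3, 5)  # and back again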
Worker() - instances of this class run in separate processes, wait for jobs arriving on the queue task_q, call the "expensive function", and put the results, tagged with the id, into the result queue results_q.
dispatcher() runs in a separate thread, receives the results from results_q, and puts them into a shared dictionary indexed by id. It then sets an event so the getz() instances check whether their results have arrived. Note that all getz() instances share this single event, so each of them has to re-check the dictionary in a loop: a given set() may have been meant for a different waiter.
main - starts the workers, starts the dispatcher, calls recfill(), and cleans up.
from concurrent.futures import ThreadPoolExecutor as Pool
import multiprocessing
import threading
import numpy as np
import matplotlib.pyplot as plt
def contpl(p1, p2, Z):
"""
    plots a filled-contour (color density) plot of Z over the parameter grids p1, p2
"""
plt.figure(figsize=(5,5))
plt.contourf(p1, p2, Z, 3, cmap='RdYlGn')
plt.show()
def sample_f(x, y):
"""
    simple sample monotonous function to plot quarter-circles
returns integer values from 0 to 2
"""
return np.round(1.4 * np.sqrt(x**2 + y**2))
def getz(ix, iy, mp_params):
"""
gets the values of the function sample_f for (x, y) values
from the parameter arrays p1, p2 for the indices ix, iy
if the value has not been already computed,
send the job to a worker, then wait until the result is ready
"""
task_q = mp_params["task_q"]
result_event = mp_params["result_event"]
result_dict = mp_params["result_dict"]
z = Zarr[ix, iy]
if z >= 0:
return z # nice, the point has already been calculated
    # otherwise z is -1 from the array initialisation
else:
# compute "flattened index" of a point as id
dims = Zarr.shape
id = np.ravel_multi_index((ix, iy), dims)
task_q.put((id, p1[ix, iy], p2[ix, iy])) # send the job, targeted by the id, to workers
        # now wait until the dispatcher signals that new results have arrived
while True:
result_event.wait()
try:
# anything for me?
z = result_dict.pop(id)
result_event.clear()
break
except KeyError:
pass
# now the point is computed, write the value into the array
Zarr[ix, iy] = z
return z
def recfill(ix, iy, mp_params):
"""
recursive function to compute values of a monotonous function
on a 2D square (sub-)array of parameters
"""
(ix0, ix1) = ix # x indices
(iy0, iy1) = iy # y indices
z0 = getz(ix0, iy0, mp_params) # get the bottom left point
# if the array size is one in all dimensions, we reached the recursion limit
if (ix0 == ix1) and (iy0 == iy1):
return
else:
# get the top right point
z1 = getz(ix1, iy1, mp_params)
# if the values for bottom left and top right are equal, they are the same for all
# elements in between
if z0 == z1:
Zarr[ix0:ix1+1, iy0:iy1+1] = z0 # fill in the subarray
return # and we are done for this recursion branch
else:
# divide the sub-array by half in each dimension
xhalf = (ix1 - ix0 + 1) // 2
yhalf = (iy1 - iy0 + 1) // 2
ixlo = (ix0, ix0+xhalf-1)
iylo = (iy0, iy0+yhalf-1)
ixhi = (ix0+xhalf, ix1)
iyhi = (iy0+yhalf, iy1)
# prepare arguments for the map function
l1 = [(ixlo, iylo), (ixlo, iyhi), (ixhi, iylo), (ixhi, iyhi)]
(ixs, iys) = zip(*l1)
mpps = [mp_params]*4
            # and now a multithreaded recursive call for each quarter of the initial sub-array
with Pool() as p:
p.map(recfill, ixs, iys, mpps)
return
class Worker(multiprocessing.Process):
"""
adapted from
https://pymotw.com/3/multiprocessing/communication.html
"""
def __init__(self, mp_params):
multiprocessing.Process.__init__(self)
self.task_queue = mp_params["task_q"]
self.result_queue = mp_params["results_q"]
def run(self):
proc_name = self.name
while True:
job = self.task_queue.get()
if job is None:
print('{}: Exiting'.format(proc_name))
break
(id, x, y) = job
result = sample_f(x, y)
answer = (id, result)
self.result_queue.put(answer)
def dispatcher(mp_params):
"""
receives the computation results from the results queue,
puts them into a shared dictionary,
and notifies all clients per event,
that they should check the dictionary,
if there is anything for them
"""
result_event = mp_params["result_event"]
result_dict = mp_params["result_dict"]
results_q = mp_params["results_q"]
while True:
qitem = results_q.get()
if qitem is not None:
(id, result) = qitem
result_dict[id] = result
result_event.set()
else:
break
if __name__ == '__main__':
result_event = threading.Event()
num_workers = multiprocessing.cpu_count()
task_q = multiprocessing.SimpleQueue()
results_q = multiprocessing.Queue() # why using SimpleQueue here would hang the program?
result_dict = {}
mp_params = {}
mp_params["task_q"] = task_q
mp_params["results_q"] = results_q
mp_params["result_dict"] = result_dict
mp_params["result_event"] = result_event
mp_params["num_workers"] = num_workers
print('Creating {} workers'.format(num_workers))
workers = [Worker(mp_params) for i in range(num_workers)]
for w in workers:
w.start()
# creating dispatcher thread
t = threading.Thread(target=dispatcher, args=(mp_params, ))
t.start()
# creating parameter arrays
arrsize = 128
xvec = np.linspace(0, 1, arrsize)
yvec = np.linspace(0, 1, arrsize)
(p1, p2) = np.meshgrid(xvec, yvec)
# initialize the results array
# our sample_f returns only non-negative values
# therefore fill in with -1 to indicate the values
# which have not been computed yet
Zarr = np.full_like(p1, -1, dtype=np.int8)
# now call our recursive function
# to compute all array values
recfill((0,arrsize-1), (0,arrsize-1), mp_params)
    # clean up
    for i in range(num_workers):
        task_q.put(None)  # poison pill: stop all workers
    for w in workers:
        w.join()
    results_q.put(None)  # stop the dispatcher
    t.join()
# plot the results
contpl(p1, p2, Zarr)
# and check the results by comparing with directly
# calculated values
Z = sample_f(p1, p2)
assert np.all(Z == Zarr)