How to send jobs from multiple client threads to a multiprocessing worker pool and return the results (Python)?

Date: 2017-11-04 12:17:16

Tags: python multithreading recursion python-multiprocessing python-multithreading

I need to compute an expensive function many times and want to use all processor cores. It would be relatively simple if I had all the sets of function arguments at once: then I could just use multiprocessing pool.map. But I do not have them all at the same time, and I also want to avoid launching a separate process for every single function evaluation. So I want to start a pool of workers (no problem), send jobs to them from the clients (using a queue, no problem), and then return the results to the clients (but how?).

More specifically, I need to compute my function on a multidimensional grid of parameter arrays. In my case, the computation at most grid points can be avoided by exploiting two particular properties of my function: it is monotonically non-decreasing in all of its arguments, and it takes only a few discrete values. Therefore, if the function values at two distinct points of the grid are equal, the value is the same for all points in between. We then divide the grid into parts and use recursion. However, using recursion means I cannot easily use pool.map.
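Stripped of all multiprocessing, the recursive shortcut described above can be sketched like this (a toy version; `f` here takes grid indices directly, unlike the real code further down):

```python
import numpy as np

def fill_monotone(f, Z, ix, iy):
    """Fill Z over the index ranges ix=(i0, i1), iy=(j0, j1), skipping
    interior evaluations whenever the corner values coincide."""
    (i0, i1), (j0, j1) = ix, iy
    z0 = f(i0, j0)                      # bottom-left corner
    if (i0, j0) == (i1, j1):            # single cell: recursion bottom
        Z[i0, j0] = z0
        return
    z1 = f(i1, j1)                      # top-right corner
    if z0 == z1:                        # monotone + equal corners =>
        Z[i0:i1 + 1, j0:j1 + 1] = z0    # constant on the whole block
        return
    im, jm = (i0 + i1) // 2, (j0 + j1) // 2
    for sub_ix in ((i0, im), (im + 1, i1)):
        for sub_iy in ((j0, jm), (jm + 1, j1)):
            if sub_ix[0] <= sub_ix[1] and sub_iy[0] <= sub_iy[1]:
                fill_monotone(f, Z, sub_ix, sub_iy)

# demo on an 8x8 grid with a quarter-circle-like integer function
f = lambda i, j: int(round(1.4 * ((i / 7) ** 2 + (j / 7) ** 2) ** 0.5))
Z = np.full((8, 8), -1)
fill_monotone(f, Z, (0, 7), (0, 7))
```

The real problem, of course, is that `f` is expensive and should run on worker processes, which is where the recursion stops mapping cleanly onto `pool.map`.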

Actually I have found a working solution, but I guess there may be a more direct and more reliable way.

I set up a separate dispatcher thread. It gets all the results from the workers through a results queue and then lets the clients pick up their results. Below is the code for a simplified (2d) case.

EDIT: in response to the comment by Shihab Shahriar. Most of the code is needed to make the whole thing work, but is not directly related to the question. Specifically:

contpl() and sample_f() are auxiliary functions. In the real problem, instead of sample_f() there would be a complicated and expensive simulation.

recfill() is the recursive function. It needs the results of the expensive computations and receives them by calling getz(). It launches further instances of itself recursively in separate threads. The other details are not important here.

getz() is the essential part of my solution: it is a proxy between the recursive functions above and the worker pool. It sends the job parameters, tagged by an id, to the worker pool via the queue task_q, then waits for an event from dispatcher() and checks whether the computation results have arrived, and then returns them to its caller. Since the parent recfill() instances run in multiple threads, so do the getz() instances.

Worker() - instances of this class run in separate processes, wait for jobs arriving via the queue task_q, call the "expensive function", and put the results, tagged with the id, into the results queue results_q.

dispatcher() runs in a separate thread, receives the results from results_q, and puts them into a shared dictionary indexed by id. It then sends an event to the getz() instances so that they check whether their results have arrived.

main - starts the workers, starts the dispatcher, calls recfill() and cleans up.

from concurrent.futures import ThreadPoolExecutor as Pool
import multiprocessing
import threading

import numpy as np

import matplotlib.pyplot as plt


def contpl(p1, p2, Z):
    """
    plots color density plot of XYZf
    """
    plt.figure(figsize=(5,5))
    plt.contourf(p1, p2, Z, 3, cmap='RdYlGn')
    plt.show()


def sample_f(x, y):
    """
    simple sample monotonous function to plot quarter-circles
    returns integer values from 0 to 2
    """
    return np.round(1.4 * np.sqrt(x**2 + y**2))


def getz(ix, iy, mp_params):
    """
    gets the values of the function sample_f for (x, y) values
    from the parameter arrays p1, p2 for the indices ix, iy
    if the value has not been already computed,
    send the job to a worker, then wait until the result is ready
    """
    task_q = mp_params["task_q"]
    result_event = mp_params["result_event"]
    result_dict = mp_params["result_dict"]
    num_workers = mp_params["num_workers"]
    results_q = mp_params["results_q"]

    z = Zarr[ix, iy]
    if z >= 0:
        return z     # nice, the point has already been calculated
    # otherwise z is -1 from the array initialisation
    else:
        # compute "flattened index" of a point as id
        dims = Zarr.shape
        id = np.ravel_multi_index((ix, iy), dims)

        task_q.put((id, p1[ix, iy], p2[ix, iy])) # send the job, targeted by the id, to workers

        # now wait until the dispatcher signals that a result has arrived
        while True:
            result_event.wait()
            try:
                # anything for me?
                z = result_dict.pop(id)
                result_event.clear()
                break
            except KeyError:
                pass
        # now the point is computed, write the value into the array
        Zarr[ix, iy] = z
        return z

def recfill(ix, iy, mp_params):
    """
    recursive function to compute values of a monotonous function
    on a 2D square (sub-)array of parameters
    """

    (ix0, ix1) = ix # x indices
    (iy0, iy1) = iy # y indices

    z0 = getz(ix0, iy0, mp_params) # get the bottom left point

    # if the array size is one in all dimensions, we reached the recursion limit
    if (ix0 == ix1) and (iy0 == iy1):
        return

    else:
        # get the top right point
        z1 = getz(ix1, iy1, mp_params)
        # if the values for bottom left and top right are equal, they are the same for all
        # elements in between
        if z0 == z1:
            Zarr[ix0:ix1+1, iy0:iy1+1] = z0 # fill in the subarray
            return # and we are done for this recursion branch

        else:
            # divide the sub-array by half in each dimension
            xhalf = (ix1 - ix0 + 1) // 2
            yhalf = (iy1 - iy0 + 1) // 2

            ixlo = (ix0, ix0+xhalf-1)
            iylo = (iy0, iy0+yhalf-1)
            ixhi = (ix0+xhalf, ix1)
            iyhi = (iy0+yhalf, iy1)

            # prepare arguments for the map function
            l1 = [(ixlo, iylo), (ixlo, iyhi), (ixhi, iylo), (ixhi, iyhi)]
            (ixs, iys) = zip(*l1)
            mpps = [mp_params]*4

            # and now a multithreaded recursive call for each quarter of the initial sub-array
            with Pool() as p:
                p.map(recfill, ixs, iys, mpps)

            return


class Worker(multiprocessing.Process):
    """
    adapted from
    https://pymotw.com/3/multiprocessing/communication.html
    """
    def __init__(self, mp_params):
        multiprocessing.Process.__init__(self)
        self.task_queue = mp_params["task_q"]
        self.result_queue = mp_params["results_q"]

    def run(self):
        proc_name = self.name
        while True:
            job = self.task_queue.get()
            if job is None:
                print('{}: Exiting'.format(proc_name))
                break

            (id, x, y) = job
            result = sample_f(x, y)

            answer = (id, result)
            self.result_queue.put(answer)


def dispatcher(mp_params):
    """
    receives the computation results from the results queue,
    puts them into a shared dictionary,
    and notifies all clients per event,
    that they should check the dictionary,
    if there is anything for them
    """
    result_event = mp_params["result_event"]
    result_dict = mp_params["result_dict"]
    results_q = mp_params["results_q"]

    while True:
        qitem = results_q.get()
        if qitem is not None:
            (id, result) = qitem
            result_dict[id] = result
            result_event.set()
        else:
            break


if __name__ == '__main__':

    result_event = threading.Event()
    num_workers = multiprocessing.cpu_count()
    task_q = multiprocessing.SimpleQueue()
    results_q = multiprocessing.Queue() # why using SimpleQueue here would hang the program?
    result_dict = {}


    mp_params = {}
    mp_params["task_q"] = task_q
    mp_params["results_q"] = results_q
    mp_params["result_dict"] = result_dict
    mp_params["result_event"] = result_event
    mp_params["num_workers"] = num_workers


    print('Creating {} workers'.format(num_workers))

    workers = [Worker(mp_params) for i in range(num_workers)]
    for w in workers:
        w.start()

    # creating dispatcher thread
    t = threading.Thread(target=dispatcher, args=(mp_params, ))
    t.start()

    # creating parameter arrays
    arrsize = 128
    xvec = np.linspace(0, 1, arrsize)
    yvec = np.linspace(0, 1, arrsize)
    (p1, p2) = np.meshgrid(xvec, yvec)

    # initialize the results array
    # our sample_f returns only non-negative values
    # therefore fill in with -1 to indicate the values
    # which have not been computed yet
    Zarr = np.full_like(p1, -1, dtype=np.int8)

    # now call our recursive function
    # to compute all array values
    recfill((0,arrsize-1), (0,arrsize-1), mp_params)

    # clean up
    for i in range(num_workers):
        task_q.put(None) # stop all workers

    results_q.put(None) # stop dispatcher
    t.join()

    # plot the results
    contpl(p1, p2, Zarr)


    # and check the results by comparing with directly
    # calculated values
    Z = sample_f(p1, p2)
    assert np.all(Z == Zarr)

0 Answers:

No answers