Changing dictionary values in a Python process pool

Asked: 2019-05-27 13:32:49

Tags: python parallel-processing

I am trying to modify a dictionary in Python from inside a process pool, but the dictionary is unchanged after the pool finishes.
Here is a minimal example of the problem (the returned batch_input is all zeros, even though per_batch_build changes the relevant values inside the workers):

from multiprocessing import Pool, freeze_support
import numpy as np
import itertools

def test_process():
    batch_size = 2
    batch_input = {'part_evecs': np.zeros((2, 10, 10)),
                   'model_evecs': np.zeros((2, 10, 10)),
                   }

    batch_model_dist = np.zeros((2, 10, 10))

    pool = Pool(4)
    # zip, not the Python 2-only itertools.izip
    batch_output = pool.map(per_batch_build, zip(itertools.repeat(batch_input),
                                                 itertools.repeat(batch_model_dist),
                                                 list(range(batch_size))))
    pool.close()
    pool.join()

    return batch_input, batch_model_dist


# @profile
# def per_batch_build(batch_input, batch_model_dist, batch_part_dist, dataset, i_batch):
def per_batch_build(tuple_input):
    batch_input, batch_model_dist, i_batch = tuple_input

    batch_model_dist[i_batch] = np.ones((10,10))

    batch_input['part_evecs'][i_batch] = np.ones((10,10))
    batch_input['model_evecs'][i_batch] = np.ones((10,10))

But unfortunately batch_input, batch_model_dist and batch_part_dist are all zeros, even though batch_input is non-zero when printed inside per_batch_build.
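One way around this (not in the original post, and simplified to the part_evecs/model_evecs arrays only) is to stop mutating the dictionary in the workers entirely: each worker returns its slice, and the parent writes the results back into the dictionary it owns. A minimal sketch:

```python
from multiprocessing import Pool
import numpy as np

def per_batch_build(i_batch):
    # Hypothetical worker: compute this batch's slice and return it,
    # instead of mutating a pickled copy of the parent's arrays.
    return i_batch, np.ones((10, 10))

def test_process():
    batch_size = 2
    batch_input = {'part_evecs': np.zeros((2, 10, 10)),
                   'model_evecs': np.zeros((2, 10, 10))}
    with Pool(2) as pool:
        for i_batch, block in pool.map(per_batch_build, range(batch_size)):
            # Write back in the parent process, where the real arrays live.
            batch_input['part_evecs'][i_batch] = block
            batch_input['model_evecs'][i_batch] = block
    return batch_input
```

This sidesteps shared state altogether, at the cost of pickling each result back to the parent.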

Using the solution suggested in earlier discussions, the result stays the same (the output arrays are all zeros):

from multiprocessing import Pool, freeze_support, Manager, Array
import numpy as np
import itertools
import ctypes

def test_process():
    manager = Manager()

    shared_array_base = Array(ctypes.c_double, [0] * (2*10*10))
    shared_array = np.ctypeslib.as_array(shared_array_base.get_obj())
    shared_array = shared_array.reshape((2,10,10))

    batch_size = 2
    batch_input = manager.dict({'part_evecs': shared_array,
                               # 'model_evecs': np.zeros((2, 10, 10)),
                               })


    batch_model_dist = np.zeros((2, 10, 10))

    pool = Pool(4)
    # zip, not the Python 2-only itertools.izip
    batch_output = pool.map(per_batch_build, zip(itertools.repeat(batch_input),
                                                 itertools.repeat(batch_model_dist),
                                                 list(range(batch_size))))
    pool.close()
    pool.join()

    return batch_input, batch_model_dist


# @profile
# def per_batch_build(batch_input, batch_model_dist, batch_part_dist, dataset, i_batch):
def per_batch_build(tuple_input):
    batch_input, batch_model_dist, i_batch = tuple_input

    batch_model_dist[i_batch] = np.ones((10,10))

    batch_input['part_evecs'][i_batch] = np.ones((10,10))
    # batch_input['model_evecs'][i_batch] = np.ones((10,10))
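For completeness (this is not from the post): the ctypes Array in the second attempt can work, but the workers must attach to the shared buffer themselves, for example through a Pool initializer, rather than receive a pickled copy of the numpy view through map. A sketch under that assumption:

```python
from multiprocessing import Pool, Array
import ctypes
import numpy as np

shared_array = None  # set in each worker by the initializer

def init_worker(shared_base):
    # Attach this worker's numpy view to the one shared buffer.
    global shared_array
    shared_array = np.ctypeslib.as_array(shared_base.get_obj()).reshape((2, 10, 10))

def per_batch_build(i_batch):
    shared_array[i_batch] = np.ones((10, 10))  # this write lands in shared memory

def test_process():
    shared_base = Array(ctypes.c_double, 2 * 10 * 10)
    with Pool(2, initializer=init_worker, initargs=(shared_base,)) as pool:
        pool.map(per_batch_build, range(2))
    return np.ctypeslib.as_array(shared_base.get_obj()).reshape((2, 10, 10))
```

Here the parent reads non-zero values after the pool joins, because both sides view the same underlying memory.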

1 Answer:

Answer 0 (score: 0)

You are changing a copy of the object inside per_batch_build. You gave it the same name in both functions, which may be what is confusing you.

Add print(id(batch_model_dist)) inside both functions and see for yourself.
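A self-contained way to observe the same effect (a sketch, not from the answer): mutate a module-level array in the workers and compare sums. Each worker sees its own copy of the parent's array, so the children's writes never reach the parent.

```python
from multiprocessing import Pool
import numpy as np

data = np.zeros(3)  # created in the parent process

def worker(_):
    data[:] = 1.0      # mutates this worker's copy only
    return data.sum()

def demo():
    with Pool(2) as pool:
        child_sums = pool.map(worker, range(2))
    # Children each saw their write; the parent's array is untouched.
    return child_sums, data.sum()
```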

[EDIT] I should probably also link a related answer, for example:

Is shared readonly data copied to different processes for multiprocessing?
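The same copy semantics also explain why the Manager attempt in the question stayed at zero: indexing into an array fetched from a manager dict mutates a local copy. A hedged sketch (simplified shapes, not the post's code) that reassigns the key so the change propagates back through the proxy:

```python
from multiprocessing import Pool, Manager
import numpy as np

def worker(args):
    shared, i = args
    block = shared['part_evecs']   # fetching returns a local copy
    block[i] = 1.0
    shared['part_evecs'] = block   # reassigning the key syncs it back

def demo():
    with Manager() as manager:
        shared = manager.dict({'part_evecs': np.zeros((2, 3))})
        # Pool(1) keeps the read-modify-write sequential; with more
        # workers this pattern would need a lock to avoid lost updates.
        with Pool(1) as pool:
            pool.map(worker, [(shared, i) for i in range(2)])
        return np.array(shared['part_evecs'])
```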