I am trying to modify a Python dictionary from inside a process pool, but the dictionary is unchanged once the pool finishes.
Here is a minimal example of the problem. The returned batch_input is all zeros, even though per_batch_build sets the relevant values inside the worker:
from multiprocessing import Pool, freeze_support
import numpy as np
import itertools


def test_process():
    batch_size = 2
    batch_input = {'part_evecs': np.zeros((2, 10, 10)),
                   'model_evecs': np.zeros((2, 10, 10)),
                   }
    batch_model_dist = np.zeros((2, 10, 10))
    pool = Pool(4)
    # zip replaces itertools.izip (Python 2 only); it stops after batch_size items
    batch_output = pool.map(per_batch_build, zip(itertools.repeat(batch_input),
                                                 itertools.repeat(batch_model_dist),
                                                 list(range(batch_size))))
    pool.close()
    pool.join()
    return batch_input, batch_model_dist


# @profile
# def per_batch_build(batch_input, batch_model_dist, batch_part_dist, dataset, i_batch):
def per_batch_build(tuple_input):
    batch_input, batch_model_dist, i_batch = tuple_input
    batch_model_dist[i_batch] = np.ones((10, 10))
    batch_input['part_evecs'][i_batch] = np.ones((10, 10))
    batch_input['model_evecs'][i_batch] = np.ones((10, 10))


if __name__ == '__main__':
    freeze_support()
    batch_input, batch_model_dist = test_process()
    # both sums are still 0.0
    print(batch_input['part_evecs'].sum(), batch_model_dist.sum())
Unfortunately, batch_input and batch_model_dist are all zeros after the pool finishes, although batch_input is non-zero when I print it inside per_batch_build.
Using the solution suggested in an earlier discussion, the result is the same (the output arrays are all zeros):
from multiprocessing import Pool, freeze_support, Manager, Array
import numpy as np
import itertools
import ctypes


def test_process():
    manager = Manager()
    shared_array_base = Array(ctypes.c_double, [0] * (2*10*10))
    shared_array = np.ctypeslib.as_array(shared_array_base.get_obj())
    shared_array = shared_array.reshape((2, 10, 10))
    batch_size = 2
    batch_input = manager.dict({'part_evecs': shared_array,
                                # 'model_evecs': np.zeros((2, 10, 10)),
                                })
    batch_model_dist = np.zeros((2, 10, 10))
    pool = Pool(4)
    # zip replaces itertools.izip (Python 2 only); it stops after batch_size items
    batch_output = pool.map(per_batch_build, zip(itertools.repeat(batch_input),
                                                 itertools.repeat(batch_model_dist),
                                                 list(range(batch_size))))
    pool.close()
    pool.join()
    return batch_input, batch_model_dist


# @profile
# def per_batch_build(batch_input, batch_model_dist, batch_part_dist, dataset, i_batch):
def per_batch_build(tuple_input):
    batch_input, batch_model_dist, i_batch = tuple_input
    batch_model_dist[i_batch] = np.ones((10, 10))
    batch_input['part_evecs'][i_batch] = np.ones((10, 10))
    # batch_input['model_evecs'][i_batch] = np.ones((10, 10))


if __name__ == '__main__':
    freeze_support()
    batch_input, batch_model_dist = test_process()
    # both sums are still 0.0
    print(batch_input['part_evecs'].sum(), batch_model_dist.sum())
Answer (score: 0)
You are modifying copies of the objects inside per_batch_build, not the originals. Because you use the same names in both functions, this is easy to miss.
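The same copying explains why the manager.dict attempt in the question still returns zeros: putting shared_array into a manager dict pickles it into the manager process, and every batch_input['part_evecs'] lookup returns a fresh copy, so item assignment on that copy is lost. The Python documentation notes that changes to mutable values inside dict and list proxies are not propagated, and that the modified object has to be reassigned to the proxy. A minimal single-process sketch of this behaviour (the function name demo is purely illustrative):

```python
from multiprocessing import Manager

import numpy as np


def demo():
    manager = Manager()
    d = manager.dict({'part_evecs': np.zeros((2, 10, 10))})

    # d['part_evecs'] returns a new unpickled copy; this assignment changes only that copy
    d['part_evecs'][0] = np.ones((10, 10))
    lost = float(d['part_evecs'].sum())

    # read-modify-write: fetch the copy, change it, then assign it back to the proxy
    arr = d['part_evecs']
    arr[0] = np.ones((10, 10))
    d['part_evecs'] = arr
    kept = float(d['part_evecs'].sum())

    manager.shutdown()
    return lost, kept


if __name__ == '__main__':
    print(demo())  # (0.0, 100.0)
```

Reassigning works, but it transfers the whole array on every read and write, and two workers doing this read-modify-write on the same key at the same time can overwrite each other's changes.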
Add print(id(batch_model_dist)) to both functions and see for yourself.
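A sketch of that check, combined with the usual fix of returning each worker's result and writing it into the parent's array. It is a simplified version of the question's code, assuming Python 3; the worker argument is trimmed to (batch_model_dist, i_batch) for brevity, and the process ID is printed alongside id() because the two objects live in different processes:

```python
import os
from multiprocessing import Pool

import numpy as np


def per_batch_build(args):
    batch_model_dist, i_batch = args
    # a different process, holding its own unpickled copy of the array
    print('worker', os.getpid(), id(batch_model_dist))
    batch_model_dist[i_batch] = np.ones((10, 10))  # changes the worker's copy only
    return i_batch, batch_model_dist[i_batch]      # send the result back instead


def test_process(batch_size=2):
    batch_model_dist = np.zeros((batch_size, 10, 10))
    print('parent', os.getpid(), id(batch_model_dist))
    with Pool(4) as pool:
        results = pool.map(per_batch_build,
                           [(batch_model_dist, i) for i in range(batch_size)])
    # the parent's array is still all zeros here; copy the returned slices in
    for i_batch, row in results:
        batch_model_dist[i_batch] = row
    return batch_model_dist


if __name__ == '__main__':
    print(test_process().sum())  # 200.0
```

The worker PIDs differ from the parent's, and pool.map only brings back what the workers return, so the assembly step has to happen in the parent.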
[Edit] I should probably also link a related answer, for example:
Is shared readonly data copied to different processes for multiprocessing?
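If the workers really need to write into one array in place, one common pattern is a multiprocessing.RawArray handed to each worker through the Pool initializer and wrapped with np.frombuffer. The sketch below assumes every task writes a different slice, so no lock is used; SHAPE, _shared and _init_worker are names chosen for illustration. A synchronized Array cannot simply be passed as a pool.map argument, because such objects may only be shared between processes through inheritance, which is why it goes through initargs here.

```python
import ctypes
from multiprocessing import Pool, RawArray

import numpy as np

SHAPE = (2, 10, 10)  # same shape as batch_model_dist in the question
_shared = None       # set in each worker process by _init_worker


def _init_worker(raw):
    # runs once per worker; raw is handed over at process start, not pickled per task
    global _shared
    _shared = np.frombuffer(raw, dtype=np.float64).reshape(SHAPE)


def per_batch_build(i_batch):
    # writes go straight into the shared buffer; each task touches its own slice
    _shared[i_batch] = np.ones(SHAPE[1:])


def test_process():
    raw = RawArray(ctypes.c_double, int(np.prod(SHAPE)))  # zero-initialised
    with Pool(4, initializer=_init_worker, initargs=(raw,)) as pool:
        pool.map(per_batch_build, range(SHAPE[0]))
    # a NumPy view of the same buffer in the parent process
    return np.frombuffer(raw, dtype=np.float64).reshape(SHAPE)


if __name__ == '__main__':
    print(test_process().sum())  # 200.0
```

If several tasks could write to overlapping regions, a lock (for example multiprocessing.Array with its get_lock()) would be needed around the writes.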