我正在尝试使用 multiprocessing.pool.Pool.starmap 函数,其中一些参数是我自己定义的类的对象。调用 multiprocessing.pool.Pool.starmap 后,出现以下错误:无法 pickle(序列化)_thread.RLock 对象(can't pickle _thread.RLock objects)
我一直在尝试让这个类变得可序列化(picklable),但这似乎并不能解决问题。
import multiprocessing
from multiprocessing.pool import Pool
class SomeClass():
self.docList = [SomeClass2, SomeClass2, SomeClass2,...]
# Use half of the available cores, but never fewer than one worker:
# Pool(processes=0) raises ValueError on a single-core machine.
num_of_cpu = max(1, multiprocessing.cpu_count() // 2)
# One (docList, key) tuple per key; starmap unpacks each tuple into the
# two positional parameters of build_chunks.
arguments_in_parallel = [(self.docList, key) for key in self.tfidfDic]
with Pool(processes=num_of_cpu) as pool:
    results = pool.starmap(build_chunks, arguments_in_parallel)
class SomeClass2():
.....
def build_chunks(SomeClass2_list, key):
....
其中self.docList是包含我的类对象的列表。
答案 0 :(得分:0)
我找到了解决方案!我没有再用元组(tuple)作为参数,而是把参数放进一个字典,并在函数内部再把它们取出来。另外,我把模块换成了 multiprocessing.dummy.Pool(基于线程的池),并改用 map 函数。
这是解决方案:
import multiprocessing
from multiprocessing.dummy import Pool as ThreadPool
class SomeClass():
self.docList = [SomeClass2, SomeClass2, SomeClass2,...]
# Use half of the available cores, but never fewer than one worker:
# a pool size of 0 raises ValueError on a single-core machine.
num_of_cpu = max(1, multiprocessing.cpu_count() // 2)
# One dict per key; the worker reads the "docList" and "key" entries.
arguments_in_parallel = [
    {"docList": self.docList, "key": key} for key in self.tfidfDic
]
# create pool
pool = ThreadPool(num_of_cpu)
try:
    # Map the worker over the argument list built above (the original
    # snippet referenced names that are never defined in it).
    results = pool.map(build_chunks, arguments_in_parallel)
finally:
    # close the pool and wait for the work to finish, even if map() raised
    pool.close()
    pool.join()
class SomeClass2():
.....
def build_chunks(args_dict):
docList = args_dict["docList"]
key = args_dict["key"]
....
答案 1 :(得分:0)
TypeError: cannot pickle '_thread.RLock' object
表示被调用函数的某个参数不是进程安全的,即它无法被 pickle 序列化后传给子进程。
相关说明可以在这里(here,原文链接)找到,其中有一个使用 MongoClient 的示例。
下面我用一个来自 tensorflow.keras 的对象举例:
import multiprocessing
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv1D, Dense
from tensorflow.keras.callbacks import ModelCheckpoint
class Blah():
    """Demo class that averages sliding windows of an array in a process pool.

    ``process`` submits the bound method ``self.active_mean`` to the pool,
    so the instance (including ``self.model``) has to be pickled with it.
    """

    def __init__(self):
        # Width of each sliding window.
        self.sequence_length = 9
        # Amount of zero padding applied on each side before windowing.
        self.pad = self.sequence_length // 2
        self.thres = {"value": 10}
        # No model attached until the caller assigns one.
        self.model = None

    def return_model(self):
        """Build and compile a small Keras ``Sequential`` model."""
        layers = (
            Conv1D(16, 4, strides=1, padding="valid", activation="linear",
                   input_shape=(self.sequence_length, 1)),
            Dense(512, activation="relu"),
            Dense(3, activation="linear"),
        )
        net = Sequential()
        for layer in layers:
            net.add(layer)
        net.compile(optimizer="adam", loss="mse")
        return net

    def process(self, array):
        """Return ``active_mean`` of every padded sliding window of ``array``.

        Uses a pool with one worker fewer than the CPU count (at least one).
        """
        worker_count = max(1, multiprocessing.cpu_count() - 1)
        with multiprocessing.Pool(worker_count) as workers:
            padded = np.pad(array, self.pad, mode="constant",
                            constant_values=(0, 0))
            width = self.sequence_length
            window_count = padded.shape[0] - width + 1
            pairs = [
                (padded[begin:begin + width], self.thres["value"])
                for begin in range(window_count)
            ]
            # The bound method is passed on purpose: this is what forces
            # the whole instance to be pickled.
            return workers.starmap(self.active_mean, pairs)

    def active_mean(self, window, threshold):
        """Mean of ``window`` from its first to (excluding) its last value
        above ``threshold``; 0 when at most two values exceed it.
        """
        hits = np.where(window > threshold)[0]
        if len(hits) <= 2:
            return 0
        first, last = hits[0], hits[-1]
        return window[first:last].mean()
if __name__ == "__main__":
    # Guard the demo: with the "spawn" start method the worker processes
    # re-import this module, and creating a pool at import time fails.
    obj = Blah()
    chunk = 100 * np.random.rand(57)
    print(obj.process(chunk))  # This will work.
    obj.model = obj.return_model()  # Adding an unsafe class instance as attribute.
    # This doesn't work anymore: the instance now holds a model that
    # cannot be pickled, so sending the bound method to the pool fails.
    print(obj.process(chunk))
解决办法是:在多进程处理时,避免隐式地把不可序列化(不安全)的对象传给子进程。上面的示例类可以改写为:
class Blah():
    """Averages sliding windows of an array in a process pool.

    The per-window work is delegated to the module-level function
    ``active_mean``, so the pool is given a plain function rather than a
    bound method of this instance.
    """

    def __init__(self):
        # Width of each sliding window.
        self.sequence_length = 9
        # Amount of zero padding applied on each side before windowing.
        self.pad = self.sequence_length // 2
        self.thres = {"value": 10}
        # No model attached until the caller assigns one.
        self.model = None

    def return_model(self):
        """Build and compile a small Keras ``Sequential`` model."""
        net = Sequential()
        first_layer = Conv1D(
            16, 4, strides=1, padding="valid", activation="linear",
            input_shape=(self.sequence_length, 1),
        )
        net.add(first_layer)
        net.add(Dense(512, activation="relu"))
        net.add(Dense(3, activation="linear"))
        net.compile(optimizer="adam", loss="mse")
        return net

    def process(self, array):
        """Return ``active_mean`` of every padded sliding window of ``array``.

        Uses a pool with one worker fewer than the CPU count (at least one).
        """
        worker_count = max(1, multiprocessing.cpu_count() - 1)
        with multiprocessing.Pool(worker_count) as workers:
            padded = np.pad(array, self.pad, mode="constant",
                            constant_values=(0, 0))
            width = self.sequence_length
            pairs = []
            for begin in range(padded.shape[0] - width + 1):
                pairs.append((padded[begin:begin + width], self.thres["value"]))
            mean_list = workers.starmap(active_mean, pairs)
        return mean_list
def active_mean(window, threshold):
    """Return the mean of the active span of ``window``.

    The span runs from the first value above ``threshold`` up to, but not
    including, the last one. Returns 0 when at most two values exceed the
    threshold.

    This is a module-level function rather than a method of ``Blah``: it
    needs no attribute of the class.
    """
    hits = np.where(window > threshold)[0]
    if len(hits) <= 2:
        return 0
    first, last = hits[0], hits[-1]
    return window[first:last].mean()