Getting “cannot pickle '_thread.RLock' object” when using multiprocessing.pool.Pool.starmap

Time: 2019-08-26 20:33:57

Tags: python-3.x multithreading pool

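I'm trying to use multiprocessing.pool.Pool.starmap, and some of the arguments are instances of custom classes I wrote. When I call multiprocessing.pool.Pool.starmap, I get the following error: cannot pickle '_thread.RLock' object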

I've been trying to make the class picklable, but that doesn't seem to be the solution.

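import multiprocessing
from multiprocessing.pool import Pool

class SomeClass():
    def run(self):  # hypothetical method name; the question shows only the body
        # self.docList = [SomeClass2, SomeClass2, SomeClass2, ...] (instances, built elsewhere)
        # self.tfidfDic is also built elsewhere
        num_of_cpu = max(1, multiprocessing.cpu_count() // 2)
        arguments_in_parallel = []

        for key in self.tfidfDic.keys():
            arguments_in_parallel.append((self.docList, key))

        # starmap pickles every (docList, key) tuple to send it to a worker process
        with Pool(processes=num_of_cpu) as pool:
            results = pool.starmap(build_chunks, arguments_in_parallel)
        return results

class SomeClass2():
    ...  # details omitted in the question

def build_chunks(SomeClass2_list, key):
    ...  # details omitted in the question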

Here self.docList is a list of instances of my own classes.
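A minimal sketch of what goes wrong, assuming (hypothetically) that each SomeClass2 object holds a threading.RLock, for example inside a client or model object it references. Pool.starmap pickles every argument tuple before sending it to a worker process, so one unpicklable attribute anywhere in docList is enough to trigger the error. The class body and the try_pickle helper below are illustrative, not taken from the question:

```python
import pickle
import threading


class SomeClass2:
    """Hypothetical stand-in for the question's document class."""

    def __init__(self, text):
        self.text = text
        # An attribute holding a lock, as many client/model objects do internally
        self._lock = threading.RLock()


def try_pickle(obj):
    """Return None if obj pickles, otherwise the error message (hypothetical helper)."""
    try:
        pickle.dumps(obj)
        return None
    except Exception as exc:
        return str(exc)


doc_list = [SomeClass2("a"), SomeClass2("b")]
# Pool.starmap pickles each (docList, key) tuple; this reproduces that step directly
error = try_pickle((doc_list, "some_key"))
print(error)  # the exact wording of the message varies between Python versions
```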

2 Answers:

Answer 0 (score: 0)

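I found a solution! Instead of passing a tuple as the argument, I put the arguments into a dictionary and unpack it inside the function. I also switched to multiprocessing.dummy.Pool and used its map function.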

Here is the solution:

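import multiprocessing
from multiprocessing.dummy import Pool as ThreadPool

class SomeClass():
    def run(self):  # hypothetical method name; the answer shows only the body
        # self.docList = [SomeClass2, SomeClass2, SomeClass2, ...] (instances, built elsewhere)
        num_of_cpu = max(1, multiprocessing.cpu_count() // 2)
        arguments_in_parallel = []

        for key in self.tfidfDic.keys():
            arguments_in_parallel.append({"docList": self.docList, "key": key})

        # create a thread pool: multiprocessing.dummy runs the workers as threads,
        # so the arguments are shared in memory instead of being pickled
        pool = ThreadPool(num_of_cpu)
        results = pool.map(build_chunks, arguments_in_parallel)

        # close the pool and wait for the work to finish
        pool.close()
        pool.join()
        return results


class SomeClass2():
    ...  # details omitted in the answer

def build_chunks(args_dict):
    # map() passes a single argument per call, so unpack the dict here
    docList = args_dict["docList"]
    key = args_dict["key"]
    ...  # details omitted in the answer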
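The change that most likely avoids the error here is the switch to multiprocessing.dummy.Pool: it exposes the same API as multiprocessing.Pool but runs the workers as threads inside the same process, so arguments are passed by reference and never pickled. Packing the arguments into a dict is not needed for that; with a process pool, a dict containing an unpicklable object fails just like a tuple. A minimal sketch, with a hypothetical lock-holding Doc class and a hypothetical build_chunks body. Keep in mind that threads share the GIL, so CPU-bound pure-Python work will not run in parallel this way:

```python
import threading
from multiprocessing.dummy import Pool as ThreadPool


class Doc:
    """Hypothetical document class holding an unpicklable lock."""

    def __init__(self, text):
        self.text = text
        self._lock = threading.RLock()


def build_chunks(doc_list, key):
    # Hypothetical per-key work: count the documents whose text contains the key
    return key, sum(key in doc.text for doc in doc_list)


doc_list = [Doc("apple pie"), Doc("apple tart"), Doc("pear")]
keys = ["apple", "pear", "plum"]

# Threads share memory, so starmap never pickles doc_list; plain tuples work fine
with ThreadPool(2) as pool:
    results = pool.starmap(build_chunks, [(doc_list, key) for key in keys])

print(results)  # [('apple', 2), ('pear', 1), ('plum', 0)]
```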

Answer 1 (score: 0)

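TypeError: cannot pickle '_thread.RLock' object means that one of the arguments passed to the called function, or some object one of them references, cannot be pickled, so it cannot be sent to the worker processes.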
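To find out which attribute is the culprit, one option is to try pickling each attribute of the object yourself before handing it to the pool. A rough sketch; find_unpicklable and Holder are hypothetical names, not part of multiprocessing:

```python
import pickle
import threading


def find_unpicklable(obj):
    """Return the names of obj's attributes that fail to pickle (hypothetical helper)."""
    bad = []
    for name, value in vars(obj).items():
        try:
            pickle.dumps(value)
        except Exception:  # failures may surface as TypeError, PicklingError, AttributeError, ...
            bad.append(name)
    return bad


class Holder:
    def __init__(self):
        self.data = [1, 2, 3]
        self.lock = threading.RLock()


print(find_unpicklable(Holder()))  # ['lock']
```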

An example with MongoClient can be found here. Below is an example using an object from tensorflow.keras:

Dummy class

import multiprocessing
import numpy as np

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv1D, Dense
from tensorflow.keras.callbacks import ModelCheckpoint


class Blah():
    def __init__(self):
        self.sequence_length = 9
        self.pad = self.sequence_length // 2
        self.thres = { "value": 10 }
        self.model = None

    def return_model(self):
        model = Sequential()
        model.add(Conv1D(16, 4, strides=1, padding="valid", activation="linear",
                  input_shape=(self.sequence_length, 1)))
        model.add(Dense(512, activation="relu"))
        model.add(Dense(3, activation="linear"))
        model.compile(optimizer="adam", loss="mse")
        return model


    def process(self, array):
        n_workers = max(( 1, multiprocessing.cpu_count() - 1 ))
        with multiprocessing.Pool(n_workers) as workers:
            array = np.pad(array, self.pad, mode="constant", constant_values=( 0, 0 ))
            pairs = [
                    ( array[i:i+self.sequence_length], self.thres["value"] )
                    for i in range(array.shape[0] - self.sequence_length + 1)
            ]
            # self.active_mean is a bound method: pickling it for the workers also pickles self
            mean_list = workers.starmap(self.active_mean, pairs)

        return mean_list

    def active_mean(self, window, threshold):
        m = 0
        activation = np.where(window > threshold)[0]
        if len(activation) > 2:
            start = activation[0]
            stop = activation[-1]
            m = window[start:stop].mean()

        return m

Evaluation

if __name__ == "__main__":  # required where workers are started with "spawn" (Windows, macOS)
    obj = Blah()
    chunk = 100 * np.random.rand(57)
    print(obj.process(chunk))          # This will work.
    obj.model = obj.return_model()     # Adding an unpicklable class instance as attribute.
    print(obj.process(chunk))          # This doesn't work anymore (ValueError).
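The second call fails because workers.starmap(self.active_mean, pairs) has to pickle the bound method self.active_mean, and pickling a bound method also pickles the instance it is bound to, including self.model. A sketch of the same effect without TensorFlow, assuming a threading.RLock as a stand-in for the Keras model and a simplified active_mean body:

```python
import pickle
import threading


class Blah:
    def __init__(self):
        self.model = None  # stand-in for the Keras model attribute

    def active_mean(self, window, threshold):
        # Simplified body; only the pickling behaviour matters here
        return sum(x for x in window if x > threshold)


def pickles(obj):
    try:
        pickle.dumps(obj)
        return True
    except Exception:
        return False


obj = Blah()
print(pickles(obj.active_mean))   # True: self only holds picklable data
obj.model = threading.RLock()     # stand-in for an unpicklable model object
print(pickles(obj.active_mean))   # False: pickling the method pickles self as well
```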

Workaround

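Avoid implicitly pickling unpicklable objects when using multiprocessing. In the class above, the culprit is self.active_mean: passing a bound method to starmap means self, and with it self.model, has to be pickled too. Making active_mean a module-level function avoids that. The example class above becomes: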

class Blah():
    def __init__(self):
        self.sequence_length = 9
        self.pad = self.sequence_length // 2
        self.thres = { "value": 10 }
        self.model = None

    def return_model(self):
        model = Sequential()
        model.add(Conv1D(16, 4, strides=1, padding="valid", activation="linear",
                  input_shape=(self.sequence_length, 1)))
        model.add(Dense(512, activation="relu"))
        model.add(Dense(3, activation="linear"))
        model.compile(optimizer="adam", loss="mse")
        return model


    def process(self, array):
        n_workers = max(( 1, multiprocessing.cpu_count() - 1 ))
        with multiprocessing.Pool(n_workers) as workers:
            array = np.pad(array, self.pad, mode="constant", constant_values=( 0, 0 ))
            pairs = [
                    ( array[i:i+self.sequence_length], self.thres["value"] )
                    for i in range(array.shape[0] - self.sequence_length + 1)
            ]
            mean_list = workers.starmap(active_mean, pairs)

        return mean_list


def active_mean(window, threshold):
    """ This function is no longer a method of the class Blah.
        We can do this because the function itself does not require any attribute of Blah.
    """
    m = 0
    activation = np.where(window > threshold)[0]
    if len(activation) > 2:
        start = activation[0]
        stop = activation[-1]
        m = window[start:stop].mean()

    return m
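If active_mean did need attributes of Blah, another option (not the one used in this answer) is to leave the unpicklable attribute out of the pickled state with __getstate__, so the bound method can still be sent to the workers. A minimal sketch, again using a threading.RLock as a hypothetical stand-in for the Keras model; the worker-side copy then sees model as None, and the active_mean body is simplified:

```python
import pickle
import threading


class Blah:
    def __init__(self):
        self.thres = {"value": 10}
        self.model = None

    def __getstate__(self):
        # Copy the instance dict and drop the attribute that cannot be pickled;
        # the default unpickling logic restores the remaining attributes
        state = self.__dict__.copy()
        state["model"] = None
        return state

    def active_mean(self, window):
        above = [x for x in window if x > self.thres["value"]]
        return sum(above) / len(above) if above else 0


obj = Blah()
obj.model = threading.RLock()  # stand-in for an unpicklable model object
restored = pickle.loads(pickle.dumps(obj.active_mean))
print(restored.__self__.model, restored([5, 20, 30]))  # None 25.0
```

The original obj keeps its model, because __getstate__ modifies a copy of the instance dict rather than the dict itself.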