关于fit_generator()和线程安全

时间:2019-06-04 09:34:20

标签: python multithreading keras multiprocessing thread-safety

上下文

为了在Keras中使用fit_generator(),我使用了一个生成器函数,例如 pseudocode -one:

def generator(data: np.array) -> (np.array, np.array):
    """Simple generator yielding some samples and targets"""

    while True:
        for batch in range(number_of_batches):
            yield data[batch * length_sequence], data[(batch + 1) * length_sequence]

在Keras的fit_generator()函数中,我想使用workers=4use_multiprocessing=True-因此,我需要一个线程安全的生成器。

在类似herehere或Keras docs的stackoverflow答案中,我读到了有关为Keras.utils.Sequence()继承这样的类的信息:

class generatorClass(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

        return ...

通过使用Sequences,Keras不会使用多个工作和多重处理发出任何警告;生成器应该是线程安全的。

无论如何,由于我正在使用自定义函数,因此我偶然发现了github上提供的Omer Zohars代码,该代码允许通过添加装饰器使generator()成为线程安全的。 代码如下:

import threading

class threadsafe_iter:
    """
    Takes an iterator/generator and makes it thread-safe by
    serializing call to the `next` method of given iterator/generator.
    """
    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return self.it.__next__()


def threadsafe_generator(f):
    """A decorator that takes a generator function and makes it thread-safe."""
    def g(*a, **kw):
        return threadsafe_iter(f(*a, **kw))

    return g

现在我可以这样做:

@threadsafe_generator
def generator(data):
    ...

问题是:使用此版本的线程安全生成器Keras仍会发出警告,说明在使用workers > 1use_multiprocessing=True时生成器必须是线程安全的,可以通过使用{{1 }}。


我的问题是:

  1. Keras是否仅由于生成器未继承Sequences而发出此警告,还是Keras还检查生成器是否总体上是线程安全的?
  2. 是否使用从Keras-docsSequences版本的线程安全方法?
  3. 是否有其他方法可以导致Keras处理线程安全生成器,这与这两个示例不同?

1 个答案:

答案 0 :(得分:1)

在对此进行研究期间,我遇到了一些信息,回答了我的问题。

  

1。?Keras是否仅由于生成器未继承序列而发出此警告,还是Keras还会检查生成器是否总体上是线程安全的?

来自Keras的gitRepo(training_generators.py),我在第46-52行中发现以下内容:

use_sequence_api = is_sequence(generator)
if not use_sequence_api and use_multiprocessing and workers > 1:
    warnings.warn(
        UserWarning('Using a generator with `use_multiprocessing=True`'
                    ' and multiple workers may duplicate your data.'
                    ' Please consider using the `keras.utils.Sequence'
                    ' class.'))

is_sequence()行中取自training_utils.py624-635的定义是:

def is_sequence(seq):
    """Determine if an object follows the Sequence API.
    # Arguments
        seq: a possible Sequence object
    # Returns
        boolean, whether the object follows the Sequence API.
    """
    # TODO Dref360: Decide which pattern to follow. First needs a new TF Version.
    return (getattr(seq, 'use_sequence_api', False)
            or set(dir(Sequence())).issubset(set(dir(seq) + ['use_sequence_api'])))

乱码这段代码Keras仅检查传递的生成器是否为Keras序列(或者使用Keras的序列API),并且通常不检查生成器是否为线程安全。


  

2。是在使用我选择为线程安全的方法,还是使用Keras-docs中的generatorClass(Sequence)-version?

正如Omer Zohar在gitHub上所展示的,他的装饰器是线程安全的-我看不出为什么它不应该像Keras那样具有线程安全性(即使Keras会像1所示那样发出警告)也没有任何原因。 根据{{​​3}},可以将thread.Lock()的实现视为线程安全的:

  

一个工厂函数,它返回一个新的原始锁对象。 一旦线程获取了它,随后尝试获取它会阻塞,直到被释放;任何线程都可以释放它。

生成器也是可腌制的,可以像这样进行测试(有关更多信息,请参见此SO-Q&A docs):

#Dump yielded data in order to check if picklable
with open("test.pickle", "wb") as outfile:
    for yielded_data in generator(data):
        pickle.dump(yielded_data, outfile, protocol=pickle.HIGHEST_PROTOCOL)

继续此操作,我甚至建议在扩展Keras的thread.Lock()时实施Sequence(),例如:

import threading

class generatorClass(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size
        self.lock = threading.Lock()   #Set self.lock

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        with self.lock:                #Use self.lock
            batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
            batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

            return ...


  

3。。是否还有其他方法可以使Keras处理线程安全生成器,而这与这两个示例不同?

在研究期间,我没有遇到任何其他方法。 当然,我不能100%肯定地说出来。