Tensorflow / Keras deadlock in fit_generator for a data generator with an internal tf model

Date: 2019-07-04 09:19:55

Tags: python tensorflow keras deadlock python-multiprocessing

Task

Run keras.Model.fit_generator with multiple workers and use_multiprocessing=True on a data generator that itself contains a TensorFlow or Keras model.

This problem is closely related to the following issue: https://github.com/tensorflow/tensorflow/issues/5448#issuecomment-258934405

def create_minimal_keras_model():
    ##### Create Model A #####
    in1 = keras.layers.Input(shape=(1,))
    d = keras.layers.Dense(1)(in1)
    a = keras.Model(inputs=in1, outputs=d)
    opt = keras.optimizers.Adam(lr=0.01)
    loss = keras.losses.mse
    a.compile(opt, loss)
    #####

    return a

class TestGenerator(keras.utils.Sequence):
    def __init__(self):
        self.len = int(1e2)
        self.model = None

        # self.init_model()

    def init_model(self):
        self.graph = tf.Graph()
        with self.graph.as_default():
            self.model = create_minimal_keras_model()

    def __len__(self):
        """
        Number of batches for generator.
        """

        return self.len

    def __getitem__(self, index):
        """
        Keras sequence method for generating batches.
        """
        if not self.model:
            self.init_model()

        if self.model:
            with self.graph.as_default():
                res = self.model.predict(np.array([1]))

        return (np.array([index]), np.array([-index/2 + 3]))

Error

Training hangs at the start of the second epoch.

Things I have tried

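  • Initializing the model when the data generator is constructed (in the main process)
  • Initializing the model on the first call to the generator (in the child process)
  • Calling tf.Session() and certain other functions before training, which causes a deadlock at the start of the first epoch instead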
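
The second bullet (building the model lazily inside the worker) can be made a little stricter by tying the model to the process that created it. The following is only a minimal sketch of that pattern and is not taken from the post: `PerProcessModel` and `build_fn` are hypothetical names, and `build_fn` stands in for something like create_minimal_keras_model. It does not by itself guarantee that the hang goes away; it only ensures that a model object copied into a forked child is never reused there.

```python
import os


class PerProcessModel:
    """Hypothetical helper: lazily builds one model per OS process."""

    def __init__(self, build_fn):
        self._build_fn = build_fn  # assumed factory, e.g. create_minimal_keras_model
        self._model = None
        self._owner_pid = None

    def get(self):
        pid = os.getpid()
        # Build on first use, or rebuild if this object was inherited
        # from a parent process through fork.
        if self._model is None or self._owner_pid != pid:
            self._model = self._build_fn()
            self._owner_pid = pid
        return self._model
```

Inside a generator, `__getitem__` would call `holder.get()` instead of checking `self.model` directly.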

Complete example code:

import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
import os


def create_minimal_keras_model():
    ##### Create Model A #####
    in1 = keras.layers.Input(shape=(1,))
    d = keras.layers.Dense(1)(in1)
    a = keras.Model(inputs=in1, outputs=d)
    opt = keras.optimizers.Adam(lr=0.01)
    loss = keras.losses.mse
    a.compile(opt, loss)
    #####

    return a

class TestGenerator(keras.utils.Sequence):
    def __init__(self):
        self.len = int(1e2)
        self.model = None

        # self.init_model()

    def init_model(self):
        self.graph = tf.Graph()
        with self.graph.as_default():
            self.model = create_minimal_keras_model()

    def __len__(self):
        """
        Number of batches for generator.
        """

        return self.len

    def __getitem__(self, index):
        """
        Keras sequence method for generating batches.
        """
        if not self.model:
            self.init_model()

        if self.model:
            with self.graph.as_default():
                res = self.model.predict(np.array([1]))

        return (np.array([index]), np.array([-index/2 + 3]))



os.environ['CUDA_VISIBLE_DEVICES'] = ''

a = create_minimal_keras_model()
a.summary()


##########################################
##### Functions Halt Before 1st Epoch #####
##########################################
# tf.Session()
# a.save_weights('tmp_model_weights.h5')
# a.load_weights('tmp_model_weights.h5')
# a.save('tmp_model.h5')
# keras.models.load_model('tmp_model.h5')
##########################################
##########################################
##########################################


##########################################
##### Functions Causing NO Deadlocks #####
##########################################
tf.get_default_session()
tf.Graph()
keras.__version__
with tf.device('/cpu:0'):
    _ = tf.constant(0)
keras.utils.plot_model(a, to_file=('tmp_plot_model.png'), show_shapes=True)
[a.get_layer(l_name).output for l_name in [a.layers[-1].name]]
_ = keras.backend.variable(4)
_ = keras.backend.image_data_format() 
_ = keras.backend.shape(tf.constant(1, shape=(5,5,5)))
_ = a.layers[0].get_config()
tf.random.set_random_seed(0)
##########################################
##########################################
##########################################


# Training hangs at the start of the second epoch
a.fit_generator(generator=TestGenerator(), steps_per_epoch=100, epochs=5, workers=4, use_multiprocessing=True)

Question

What options do I have for letting a data generator run a TensorFlow model internally during multiprocess training?
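
One direction that sidesteps the question, sketched here under a strong assumption, is to keep TensorFlow out of the worker processes altogether by running the inner model once in the main process and letting the workers serve cached batches. This only applies if the inner model's outputs do not change during training. `PrecomputedBatches` and `fake_predict` are hypothetical names that do not appear in the post; in real use the class would subclass keras.utils.Sequence and `predict_fn` would wrap the inner model.

```python
class PrecomputedBatches:
    """Hypothetical Sequence-like container of batches computed ahead of time."""

    def __init__(self, predict_fn, num_batches):
        # Run the inner model up front, before any worker process exists.
        self._batches = [predict_fn(i) for i in range(num_batches)]

    def __len__(self):
        return len(self._batches)

    def __getitem__(self, index):
        # Workers only read cached data and never call into TensorFlow.
        return self._batches[index]


def fake_predict(index):
    # Stand-in for the inner model; mirrors the (x, y) pair returned
    # by TestGenerator.__getitem__ in the post.
    return ([index], [-index / 2 + 3])
```

The trade-off is memory: every batch is held at once, so this fits small or moderately sized datasets better than streaming ones.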

The options I have identified are:

0 answers:

No answers