如何处理TensorFlow会话以同时训练多个Keras模型?

时间:2017-05-13 10:25:26

标签: session tensorflow keras

我需要同时训练多个Keras模型。我正在使用TensorFlow后端。问题是,当我尝试同时训练两个模型时,我得到Attempting to use uninitialized value

错误并不是真正相关,主要问题似乎是Keras强迫两个模型在相同的会话中使用相同的图形创建,因此它会发生冲突。

我是TensorFlow的新手,但我的直觉是答案很简单:你必须为每个Keras模型创建一个不同的会话,并在他们自己的会话中训练它们。有人可以解释一下它将如何完成?

我真的希望能够在仍然使用Keras的同时解决这个问题,而不是在纯TensorFlow中编写所有内容。任何解决方法也会受到赞赏。

2 个答案:

答案 0 :(得分:0)

我使用pythons multiprocessing https://docs.python.org/3.4/library/multiprocessing.html并行训练多个模型。

我有一个带有两个参数的函数,一个输入队列和一个输出队列,这个函数在每个进程中运行。该函数具有以下结构:

def worker(in_queue, out_queue):
    import keras

    while True:
        parameters = in_queue.get()
        network_parameters = parameters[0]
        train_inputs = parameters[1]
        train_outputs = parameters[2]
        test_inputs = parameters[3]
        test_outputs = parameters[4]

        build the network based on the given parameters

        train the network

        test the network if required

        out_queue.put(result)

从主python脚本开始,根据需要启动尽可能多的进程(并创建多个进出队列)。通过调用其队列中的put来向作业添加作业,并通过调用其队列中的get来获得结果。

答案 1 :(得分:0)

是的,Keras自动使用默认会话。 您可以使用main_window.jstf.keras.backend.get_session()来手动设置全局Keras会话(请参见documentation)。

例如:

tf.keras.backend.set_session(sess)