我使用的是Python和Keras(目前正在使用Theano后端,但我对切换毫无疑问)。我有一个神经网络,我并行加载和处理多个信息源。目前,我已经在一个单独的进程中运行每个进程,并从文件中加载自己的网络副本。这似乎浪费了RAM,所以我认为让一个多线程进程与一个所有线程使用的网络实例相比会更有效率。但是,我想知道Keras是否与后端的线程安全。如果我在不同的线程中同时在两个不同的输入上运行.predict(x)
,我是否会遇到竞争条件或其他问题?
由于
答案 0 :(得分:13)
是的,如果你注意它,Keras是线程安全的。
事实上,在强化学习中,有一种叫做Asynchronous Advantage Actor Critics (A3C)的算法,其中每个代理依赖于相同的神经网络来告诉他们在给定状态下应该做什么。换句话说,每个线程在您的问题中同时调用model.predict
。 Keras的一个示例实现是here。
但是,如果你查看代码,你应该特别注意这一行:
model._make_predict_function() # have to initialize before threading
这在Keras文档中从未被提及,但它必须使其同时工作。简而言之,_make_predict_function
是一个编译predict
函数的函数。在多线程设置中,您必须提前手动调用此函数来编译predict
,否则在您第一次运行它之前不会编译predict
函数,这在许多线程调用时会出现问题它立刻。您可以看到详细解释here。
到目前为止,我还没有遇到过Keras多线程的任何其他问题。
答案 1 :(得分:1)
引用类型fcholet:
_make_predict_function是私有API。我们不建议调用它。
在这里,用户只需简单地先调用“预测”即可。
请注意,不能保证Keras模型是线程安全的。 考虑在每个线程中具有模型的独立副本 用于CPU推断。