Keras预测不会在芹菜任务中返回

时间:2017-08-02 11:09:23

标签: django tensorflow redis celery keras

以下Keras函数(预测)在同步调用时有效

pred = model.predict(x)

但是从异步任务队列(Celery)中调用时它不起作用。 当异步调用时,Keras预测函数不会返回任何输出。

堆栈是:Django,Celery,Redis,Keras,TensorFlow

2 个答案:

答案 0 :(得分:1)

我从此Blog

获得了参考
  • Tensorflow是特定于线程的数据结构,当您调用 model.predict
  • 时,它们会在后台运行
GRAPH = tf.get_default_graph()
with GRAPH.as_default():
    pred = model.predict
return pred

但是Celery使用流程来管理其所有工人池。因此,由于您需要使用gevent或eventlet库

  

pip安装gevent

现在将celery运行为:

  

芹菜-我的网站工作人员-pool gevent -l信息

答案 1 :(得分:0)

我碰到了这个完全相同的问题,而那真是个兔子洞。想要在这里发布我的解决方案,因为这可能会节省某人一天的工作:

TensorFlow线程特定的数据结构

在TensorFlow中,当您调用model.predict(或keras.models.load_modelkeras.backend.clear_session或与之交互的几乎所有其他函数时,有两种关键的数据结构在后台运行。 TensorFlow后端):

在文档中没有明确挖掘就不清楚的是,会话和图形都是当前线程的属性。请参阅API文档herehere

在不同线程中使用TensorFlow模型

自然需要一次加载模型,然后以后多次调用.predict()

from keras.models import load_model

MY_MODEL = load_model('path/to/model/file')

def some_worker_function(inputs):
    return MY_MODEL.predict(inputs)

在像Celery这样的Web服务器或工作者池上下文中,这意味着您在导入包含load_model行的模块时将加载模型,然后另一个线程将执行some_worker_function,并运行对包含Keras模型的全局变量进行预测。但是,尝试在装入不同线程的模型上运行预测会产生“张量不是该图的元素”错误。感谢与该主题相关的几篇SO帖子,例如ValueError: Tensor Tensor(...) is not an element of this graph. When using global variable keras model。为了使它起作用,您需要保留使用的TensorFlow图-正如我们之前所看到的,该图是当前线程的属性。更新后的代码如下:

from keras.models import load_model
import tensorflow as tf

MY_MODEL = load_model('path/to/model/file')
MY_GRAPH = tf.get_default_graph()

def some_worker_function(inputs):
    with MY_GRAPH.as_default():
        return MY_MODEL.predict(inputs)

这里有些令人惊讶的变化是:如果您使用Thread,上面的代码就足够了,但是如果您使用Process es,则可以无限期地挂起。 ,Celery使用流程来管理其所有工作人员池。因此,此时,仍然在Celery上无法正常工作。

为什么这只能在Thread上使用?

在Python中,Thread与父进程共享相同的全局执行上下文。来自Python _thread docs

  

该模块提供了用于处理多个线程(也称为轻量级进程或任务)的低级原语,这些线程是共享其全局数据空间的多个控制线程。

由于线程不是实际的独立进程,因此它们使用相同的python解释器,因此要受到臭名昭著的Global Interpeter Lock(GIL)的约束。对于这次调查而言,也许更重要的是,它们与父级共享全局数据空间。

与此相反,Process是程序产生的 actual 新进程。这意味着:

  • 新的Python解释器实例(没有GIL)
  • 全局地址空间已重复

请注意此处的区别。尽管Thread可以访问共享的单个全局Session变量(内部存储在Keras的tensorflow_backend模块中),但是Process s具有Session变量的副本。

我对这个问题的最佳理解是,Session变量应该代表客户机(进程)和TensorFlow运行时之间的唯一连接,但是由于在派生过程中被复制,因此此连接信息不正确调整。这会导致TensorFlow在尝试使用以其他过程创建的Session时挂起。如果有人对TensorFlow的工作原理有更深入的了解,我很乐意听到!

解决方案/解决方法

我一直在调整Celery,以便它使用Thread s而不是Process es进行池化。这种方法有一些缺点(请参见上面的GIL注释),但这使我们只能加载一次模型。由于TensorFlow运行时会最大化所有CPU内核,因此我们实际上并没有CPU限制(因为它不是用Python编写的,因此可以避开GIL)。您必须为Celery提供一个单独的库才能进行基于线程的池化。该文档提出了两个选择:geventeventlet。然后,您将选择的库传递给工作程序via the --pool command line argument

或者,(如您已经发现@ pX0r)似乎其他Keras后端(例如Theano)没有此问题。这是有道理的,因为这些问题与TensorFlow实施细节紧密相关。我个人尚未尝试过Theano,所以您的里程可能会有所不同。

我知道这个问题是在不久前发布的,但是这个问题仍然存在,因此希望可以对某人有所帮助!