我正在尝试在芹菜的所有分叉工人之间共享TensorFlow模型。 通过这样做,我尝试一次加载模型,并由多个工作人员使用它,从而节省了机器的内存消耗(许多轻量级工作人员加载并使用了一个繁重的模型)
工作程序正在尝试对创建工作程序之前(在celeryinit文件中)初始化的会话变量使用run函数。
我尝试声明全局会话和图形变量,将tf.ConfigProto(use_per_session_threads = True)添加到我的代码中,重置默认图形,完成会话和图形,子函数中的create_local_server,tf.global_variables_initializer(),但相同结果
# Celery executing line:
celery -A Project.celeryinit worker -Q x,y --concurrency=2
# Project.celeryinit ran by celery managing process (prefork) and
# sets the global session object
graph = None
session: tf.Session = None
def create_graph():
global session
global graph
graph = tf.Graph()
with graph.as_default():
text_input = tf.placeholder(dtype=tf.string, shape=[None])
embed = hub.Module(URL_MODULE)
embedded_text = embed(text_input)
init_op = tf.group([tf.global_variables_initializer(), tf.tables_initializer()])
session = tf.Session(graph=graph, config=tf.ConfigProto(use_per_session_threads=True))
session.run(init_op)
# Celery task processed by the workers
@staticmethod
@task(name='do_stuff', bind=True, queue='y')
def run_func(input):
output = sess.run(input, feed_dict={setting.text_input: [str(new_label)]})
return output
我希望得到结果或错误,但是进程会无限期地运行。
我知道TensorFlow Serving应该可以解决此问题,但我已经将celery用于其他用例,而宁愿仅在可能的情况下使用。