如何在线程之间共享Tensorflow会话?

时间:2019-05-27 15:03:55

标签: python tensorflow parallel-processing celery

我正在尝试在芹菜的所有分叉工人之间共享TensorFlow模型。 通过这样做,我尝试一次加载模型,并由多个工作人员使用它,从而节省了机器的内存消耗(许多轻量级工作人员加载并使用了一个繁重的模型)

工作程序正在尝试对创建工作程序之前(在celeryinit文件中)初始化的会话变量使用run函数。

我尝试声明全局会话和图形变量,将tf.ConfigProto(use_per_session_threads = True)添加到我的代码中,重置默认图形,完成会话和图形,子函数中的create_local_server,tf.global_variables_initializer(),但相同结果

# Celery executing line:
celery -A Project.celeryinit worker -Q x,y --concurrency=2

# Project.celeryinit ran by celery managing process (prefork) and 
# sets the global session object
graph = None
session: tf.Session = None
def create_graph():
    global session
    global graph
    graph = tf.Graph()
    with graph.as_default():
        text_input = tf.placeholder(dtype=tf.string, shape=[None])
        embed = hub.Module(URL_MODULE)
        embedded_text = embed(text_input)
        init_op = tf.group([tf.global_variables_initializer(), tf.tables_initializer()])
        session = tf.Session(graph=graph, config=tf.ConfigProto(use_per_session_threads=True))
        session.run(init_op)

# Celery task processed by the workers
@staticmethod
@task(name='do_stuff', bind=True, queue='y')
def run_func(input):
    output = sess.run(input, feed_dict={setting.text_input: [str(new_label)]})
    return output

我希望得到结果或错误,但是进程会无限期地运行。

我知道TensorFlow Serving应该可以解决此问题,但我已经将celery用于其他用例,而宁愿仅在可能的情况下使用。

0 个答案:

没有答案