Tensorflow中的线程安全图用法

时间:2019-01-10 13:25:40

标签: python multithreading tensorflow flask keras

我有一个flask应用程序,该应用程序首先加载keras模型,然后执行预测功能。根据{{​​3}},我将张量流图保存在全局变量中,并将相同的图用于预测函数。

def load_model():
    load_model_from_file()
    global graph
    graph = tf.get_default_graph()

def predict():
    with graph.as_default():
        tagger = Tagger(self.model, preprocessor=self.p)
        return tagger.analyze(words)

@app.route('/predict', methods=['GET'])
def load_and_predict():
    load_model()
    predict()

但是,每当有多个请求发送到服务器时,这都会导致问题。如何使此代码具有线程安全性,或更具体地说,如何在多线程环境中正确使用tensorflow图?

2 个答案:

答案 0 :(得分:0)

通常,您应该在使用tensorflow中的线程时使用会话。

intra_parallel_thread_tf = 1
inter_parallel_thread_tf = 1

session_conf = tf.ConfigProto(intra_op_parallelism_threads=intra_parallel_thread_tf,
                          inter_op_parallelism_threads=inter_parallel_thread_tf)

tf.Session(graph=tf.get_default_graph(), config=session_conf)
GRAPH = tf.get_default_graph()

但这很笼统。这也取决于您得到的错误。

答案 1 :(得分:0)

您可以使其与锁同步。

import threading    
lock = threading.Lock()

def load_and_predict():
     with lock:
        load_model()
        predict()