我有一个flask应用程序,该应用程序首先加载keras模型,然后执行预测功能。根据{{3}},我将张量流图保存在全局变量中,并将相同的图用于预测函数。
def load_model():
load_model_from_file()
global graph
graph = tf.get_default_graph()
def predict():
with graph.as_default():
tagger = Tagger(self.model, preprocessor=self.p)
return tagger.analyze(words)
@app.route('/predict', methods=['GET'])
def load_and_predict():
load_model()
predict()
但是,每当有多个请求发送到服务器时,这都会导致问题。如何使此代码具有线程安全性,或更具体地说,如何在多线程环境中正确使用tensorflow图?
答案 0 :(得分:0)
通常,您应该在使用tensorflow中的线程时使用会话。
intra_parallel_thread_tf = 1
inter_parallel_thread_tf = 1
session_conf = tf.ConfigProto(intra_op_parallelism_threads=intra_parallel_thread_tf,
inter_op_parallelism_threads=inter_parallel_thread_tf)
tf.Session(graph=tf.get_default_graph(), config=session_conf)
GRAPH = tf.get_default_graph()
但这很笼统。这也取决于您得到的错误。
答案 1 :(得分:0)
您可以使其与锁同步。
import threading
lock = threading.Lock()
def load_and_predict():
with lock:
load_model()
predict()