我在python的多线程作业中使用keras模型时遇到问题。我的代码几乎与此相似:
def load_my_model():
from keras.models import load_model
global model
model = load_model(path_to_model)
def get_test_data(userid):
....
return test_data_list # returns a list of test data points for user
def predict_for_user(userid):
load_my_model()
test_data_list = get_test_data(userid)
for el in test_data_list:
model.predict(el)
from multiprocessing.pool import ThreadPool as Pool
pool = Pool(4)
result = pool.map(predict_for_user, user_id_list)
pool.close()
pool.join()
过去,仅在for循环中执行此操作就可以正常工作,但是在使用多处理中的池功能时,不行。它引发错误:
无法将feed_dict键解释为张量:Tensor Tensor(“ Placeholder:0”,shape =(300,120),dtype = float32)不是此图的元素。
根据一些在线建议,我将代码更改为此:
def load_my_model():
from keras.models import load_model
import tensorflow as tf
global model
model = load_model(path_to_model)
global graph
graph = tf.get_default_graph()
def get_test_data(userid):
....
return test_data_list # returns a list of test data points for user
def predict_for_user(userid):
load_my_model()
test_data_list = get_test_data(userid)
for el in test_data_list:
with graph.as_default():
model.predict(el)
from multiprocessing.pool import ThreadPool as Pool
pool = Pool(4)
result = pool.map(predict_for_user, user_id_list)
pool.close()
pool.join()
这在某种程度上有所帮助,因为现在该模型可以很好地预测用户,而不会引发错误,但是当它移动到池列表中的下一个用户时,它会像以前一样抛出相同的错误:
无法将feed_dict键解释为张量:Tensor Tensor(“ Placeholder:0”,shape =(300,120),dtype = float32)不是此图的元素。
我确定这与在多线程期间无法正确加载张量有关,但是不确定如何在上面的代码中解决此问题。
任何帮助将不胜感激!