Keras在多线程环境中无法正常工作

时间:2019-01-25 22:53:32

标签: python tensorflow keras python-multiprocessing

我在python的多线程作业中使用keras模型时遇到问题。我的代码几乎与此相似:

def load_my_model():
  from keras.models import load_model
  global model
  model = load_model(path_to_model)

def get_test_data(userid):
   ....
   return test_data_list  # returns a list of test data points for user

def predict_for_user(userid):
  load_my_model()
  test_data_list = get_test_data(userid)
  for el in test_data_list:
       model.predict(el)

from multiprocessing.pool import ThreadPool as Pool
pool = Pool(4)
result = pool.map(predict_for_user, user_id_list) 
pool.close()
pool.join()

过去,仅在for循环中执行此操作就可以正常工作,但是在使用多处理中的池功能时,不行。它引发错误:

  

无法将feed_dict键解释为张量:Tensor Tensor(“ Placeholder:0”,shape =(300,120),dtype = float32)不是此图的元素。

根据一些在线建议,我将代码更改为此:

def load_my_model():
  from keras.models import load_model
  import tensorflow as tf

  global model
  model = load_model(path_to_model)

  global graph
  graph = tf.get_default_graph()


def get_test_data(userid):
   ....
   return test_data_list  # returns a list of test data points for user

def predict_for_user(userid):
  load_my_model()
  test_data_list = get_test_data(userid)
  for el in test_data_list:
       with graph.as_default():
          model.predict(el)

from multiprocessing.pool import ThreadPool as Pool
pool = Pool(4)
result = pool.map(predict_for_user, user_id_list) 
pool.close()
pool.join()

这在某种程度上有所帮助,因为现在该模型可以很好地预测用户,而不会引发错误,但是当它移动到池列表中的下一个用户时,它会像以前一样抛出相同的错误:

  

无法将feed_dict键解释为张量:Tensor Tensor(“ Placeholder:0”,shape =(300,120),dtype = float32)不是此图的元素。

我确定这与在多线程期间无法正确加载张量有关,但是不确定如何在上面的代码中解决此问题。

任何帮助将不胜感激!

0 个答案:

没有答案