使用multiprocessing.pool负载keras模型会导致预测“ ValueError张量Tensor”

时间:2019-05-26 06:46:53

标签: python-3.x tf.keras

我为每个项目保存了1000多个模型。现在,我需要将所有这些模型加载到内存(数据帧)中进行预测。如果仅使用“ for”循环加载这些模型,则每次加载将比以前的模型加载慢3秒。因此,我尝试使用multiprocessing.pool(ThreadPool)。

但是,奇怪的是,使用ThreadPool将导致预测“ ValueError:Tensor Tensor”。如果使用正常加载,则预测很好。

我试过线程也收到错误消息

#following code will lead to ValueError
from multiprocessing.pool import ThreadPool as Pool
def load_model(stock):
    model_pred.at[0, stock] = keras.models.load_model (
        'C:/Users/chenp/Documents/rqpro/models/{}_model.h5'.format (stock))


pool = Pool(processes=16)
for stock in trade_stocks['stock']:
    pool.map (load_model, (stock,))

#Prediction
for stock in trade_stocks['stock']:
    model = model_pred.loc[0, stock]
    prediction = model.predict(pred_data)

#Get following msg:
ValueError: Tensor Tensor("dense_9/Softmax:0", shape=(?, 2), dtype=float32) is not an element of this graph.

#Normal code but too low efficient
for stock in trade_stocks['stock']:
    model_pred.at[0, stock] = keras.models.load_model(
           'C:/Users/chenp/Documents/rqpro/models/{}_model.h5'.format(stock))





#Get following msg:
ValueError: Tensor Tensor("dense_9/Softmax:0", shape=(?, 2), dtype=float32) is not an element of this graph.

1 个答案:

答案 0 :(得分:0)

这是因为Keras线程不安全。为了解决此问题,请在预测之前使用_make_predict_function()。有关详细的答案,请check