使用python的多处理在keras中并行化模型预测时的执行急切问题

时间:2020-03-30 05:27:43

标签: python numpy tensorflow multiprocessing eager-execution

我正在尝试使用keras在python2中提供的model.predict命令并行执行模型预测。我将tensorflow 1.14.0用于python2。我有5个模型(.h5)文件,并希望predict命令为多个输入并行运行。每个模型对输入顺序进行预测。输入并行输入。事先要加载模型文件。代码如下所示,

import matplotlib as plt
import numpy as np
import cv2
import multiprocessing
import tensorflow as tf
from contextlib import closing
import time

models=['model1.h5','model2.h5','model3.h5','model4.h5','model5.h5']
loaded_models=[]
for model in models:
    loaded_models.append(tf.keras.models.load_model(model))

def prediction(input_tuple):
    inputs,loaded_models=input_tuple
    predops=[]
    for model in loaded_models:
        predops.append(model.predict(inputs).tolist()[0])
    actops=[]
    for predop in predops:
        actops.append(predop.index(max(predop)))
    max_freqq = max(set(actops), key = actops.count) 
    return max_freqq

#....some pre-processing....#

    '''new_all_t is a list which contains tuples and each tuple has inputs from all_t 
    and the list containing loaded models which will be extracted
 in the prediction function.'''

new_all_t=[]
for elem in all_t:
    new_all_t.append((elem,loaded_models))
start_time=time.time()
with closing(multiprocessing.Pool()) as p:
    predops=p.map(prediction,new_all_t)
print('Total time taken: {}'.format(time.time() - start_time))

new_all_t是一个包含元组的列表,每个元组都有来自all_t的输入,该列表包含将在预测函数中提取的已加载模型。但是,我现在遇到以下错误,

Traceback (most recent call last):
  File "trial_mult-ips.py", line 240, in <module>
    predops=p.map(prediction,new_all_t)
  File "/usr/lib/python2.7/multiprocessing/pool.py", line 253, in map
    return self.map_async(func, iterable, chunksize).get()
  File "/usr/lib/python2.7/multiprocessing/pool.py", line 572, in get
    raise self._value
NotImplementedError: numpy() is only available when eager execution is enabled.

这究竟表明了什么?我该如何解决这个问题?

更新: 我在一开始就包含了tf.compat.v1.enable_eager_execution()和tf.compat.v1.enable_v2_behavior()行。现在我得到以下错误,

WARNING:tensorflow:From /home/nick/.local/lib/python2.7/site-packages/tensorflow/python/ops/math_grad.py:1250: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

Traceback (most recent call last):
  File "the_other_end-mp.py", line 216, in <module>
    predops=p.map(prediction,modelon)
  File "/usr/lib/python2.7/multiprocessing/pool.py", line 253, in map
    return self.map_async(func, iterable, chunksize).get()
  File "/usr/lib/python2.7/multiprocessing/pool.py", line 572, in get
    raise self._value
ValueError: Resource handles are not convertible to numpy.

我无法解释此错误消息,我该如何解决呢?任何建议都非常感谢!

0 个答案:

没有答案