TensorFlow Keras模型无法预测线程

时间:2020-05-09 15:55:33

标签: python multithreading tensorflow keras multiprocessing

我正在尝试使用TensorFlow Keras模型异步进行预测。该模型可以正常运行而无需多处理,但是,一旦使用mp.Process调用该模型,它就永远不会运行。

import multiprocessing as mp
import tensorflow as tf
import numpy as np

def predict(data):
    a = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(16,))])

    return a.predict(data)

fake_data = np.zeros((100, 16))

# Works
for i in range(4):
    print(predict(fake_data).shape)

# Never finishes running
processes = []
for i in range(4):
    p = mp.Process(target=predict, args=(fake_data,))
    p.start()
    processes.append(p)

for p in processes:
    p.join()

为什么会这样?是否有解决方法/我做错了什么?我尝试用predict包装with tf.device('/cpu:0')的内部,但问题仍然存在。

编辑:无论TF 1还是2,问题似乎都会发生

0 个答案:

没有答案