我正在尝试使用TensorFlow Keras模型异步进行预测。该模型可以正常运行而无需多处理,但是,一旦使用mp.Process调用该模型,它就永远不会运行。
import multiprocessing as mp
import tensorflow as tf
import numpy as np
def predict(data):
a = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(16,))])
return a.predict(data)
fake_data = np.zeros((100, 16))
# Works
for i in range(4):
print(predict(fake_data).shape)
# Never finishes running
processes = []
for i in range(4):
p = mp.Process(target=predict, args=(fake_data,))
p.start()
processes.append(p)
for p in processes:
p.join()
为什么会这样?是否有解决方法/我做错了什么?我尝试用predict
包装with tf.device('/cpu:0')
的内部,但问题仍然存在。
编辑:无论TF 1还是2,问题似乎都会发生