我在这个主题上见过一些blogpost about it few similar,但似乎没有一个能解决我的问题。
我已经训练了Keras模型(仅CPU),并且想使用multithreading.Pool
异步调用预测函数。但是,对predict
的呼叫只是挂起。没有抛出异常或任何异常。从主线程调用它可以正常工作。我曾尝试按照建议使用model._make_predict_function()
,但这并不能解决我的问题。
我已经设置了Jupyter笔记本来重现此内容(Keras == 2.2.4,tensorflow == 1.11.0):
In [1]: from keras.models import Sequential
from keras.layers import Dense
from multiprocessing.pool import Pool
In [2]: # Create sample model from Keras documentation
model = Sequential()
model.add(Dense(32, activation='relu', input_dim=100))
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['accuracy'])
# Generate dummy data
import numpy as np
data = np.random.random((1000, 100))
labels = np.random.randint(2, size=(1000, 1))
# Train the model, iterating on the data in batches of 32 samples
model.fit(data, labels, epochs=10, batch_size=32, verbose=0)
In [3]: test_data = np.random.random((1,100))
def predict(model, data):
return model.predict(data)
def do_predict(_=1):
print('Prediction:', predict(model, test_data))
print('Done')
In [4]: do_predict()
Out [4]: Prediction: [[0.5553096]]
Done
In [5]: with Pool(1) as pool:
pool.apply_async(do_predict, [1]).get()
pool.close()
pool.join()
在最后一步,它只是挂起。有人可以帮我找出这里的事吗?不能异步使用predict
吗?