How do I make predictions in parallel with a sequential Keras model and multiprocessing?

Asked: 2020-08-16 15:38:28

Tags: python tensorflow keras multiprocessing

I'm trying to solve a binary classification problem. To that end I created a model in Keras (with the TensorFlow backend) and trained it on the CPU. I saved the model in the TensorFlow SavedModel format using the Keras API. I'm using Kaggle Kernels with Python 3.7.6, Keras 2.4.3, and TensorFlow 2.3.0.

Here is my simplified code with some mock data (in reality I train the model on a GPU, but I don't think the problem I'm facing is related to that fact):

# setup
import numpy as np

random_state = 0

# create mockup data
train_labels = np.random.randint(2, size=(1000))
validation_labels = np.random.randint(2, size=(200))
train_features = np.random.rand(1000, 49)
validation_features = np.random.rand(200, 49)

# create and train model
from keras.models import Sequential
from keras.layers import Dense, Dropout

dim = train_features.shape[1]

model = Sequential()
model.add(Dense(dim, input_dim=dim, activation='relu'))
model.add(Dropout(0.2, seed=random_state))
model.add(Dense(1, activation='sigmoid'))

model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()

epochs = 10
batch_size = 100

model.fit(train_features, train_labels, epochs=epochs, batch_size=batch_size, verbose=1, validation_data=(validation_features, validation_labels))

# evaluate model
from sklearn.metrics import confusion_matrix

prediction_labels = model.predict_classes(validation_features)
print(confusion_matrix(validation_labels, prediction_labels))  # (y_true, y_pred)

# save model
model.save('/kaggle/working/model')

Then I load the model in a new session. When I make predictions on a single CPU, everything works fine. What I actually want to do, though, is make predictions in parallel on all 4 available CPUs (unlike this mock-up, the real problem I'm solving is much more complex and involves much more data). I tried to do this with multiprocessing like so:

# setup
import numpy as np
import pandas as pd

random_state = 0

# load model
from keras.models import load_model

model = load_model('../input/repro-model/model')

# create mockup test data
test_features = np.random.rand(1000, 49)

# make some test predictions on the entire test set and a single observation
model.predict_classes(test_features)
model.predict_classes(test_features[[0]])

# set up test dataframe for making predictions in parallel
test_df = pd.DataFrame(test_features)

seq_ids = []
for i in np.arange(1,201):
    seq_id = [i] * 5
    seq_ids.append(seq_id)

frm_ids = [np.arange(1,6)] * 200

test_df['seq_id'] = [item for sublist in seq_ids for item in sublist]
test_df['frm_id'] = [item for sublist in frm_ids for item in sublist]

# test the setup
seq_id = 1
frm_id = 1

model.predict_classes(test_df[(test_df.seq_id == seq_id) & (test_df.frm_id == frm_id)].drop(['seq_id', 'frm_id'], axis=1))

# create function for making predictions
def make_prediction(model, data, seq_id, frm_id):
    
    print(seq_id)
    
    pred = model.predict_classes(data[(data.seq_id == seq_id) & (data.frm_id == frm_id)].drop(['seq_id', 'frm_id'], axis=1))

    return pred

# make test prediction
make_prediction(model, test_df, 1, 1)

# make predictions in parallel
from multiprocessing import Pool
import itertools

workers = 4
p = Pool(processes=workers)

seq_list, frm_list = np.arange(1, 201), np.arange(1, 6)
id_pair_list = list(itertools.product(seq_list, frm_list))

predictions = p.starmap(make_prediction, [(model, test_df, id_pair[0], id_pair[1]) for id_pair in id_pair_list])
p.close()

When I run the code above, I get the following stack trace and error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-85-52e9e388daba> in <module>
      8 id_pair_list = list(itertools.product(seq_list, frm_list))
      9 
---> 10 predictions = p.starmap(make_prediction, [(model, test_df, id_pair[0], id_pair[1]) for id_pair in id_pair_list])
     11 p.close()

/opt/conda/lib/python3.7/multiprocessing/pool.py in starmap(self, func, iterable, chunksize)
    274         `func` and (a, b) becomes func(a, b).
    275         '''
--> 276         return self._map_async(func, iterable, starmapstar, chunksize).get()
    277 
    278     def starmap_async(self, func, iterable, chunksize=None, callback=None,

/opt/conda/lib/python3.7/multiprocessing/pool.py in get(self, timeout)
    655             return self._value
    656         else:
--> 657             raise self._value
    658 
    659     def _set(self, i, obj):

/opt/conda/lib/python3.7/multiprocessing/pool.py in _handle_tasks(taskqueue, put, outqueue, pool, cache)
    429                         break
    430                     try:
--> 431                         put(task)
    432                     except Exception as e:
    433                         job, idx = task[:2]

/opt/conda/lib/python3.7/multiprocessing/connection.py in send(self, obj)
    204         self._check_closed()
    205         self._check_writable()
--> 206         self._send_bytes(_ForkingPickler.dumps(obj))
    207 
    208     def recv_bytes(self, maxlength=None):

/opt/conda/lib/python3.7/multiprocessing/reduction.py in dumps(cls, obj, protocol)
     49     def dumps(cls, obj, protocol=None):
     50         buf = io.BytesIO()
---> 51         cls(buf, protocol).dump(obj)
     52         return buf.getbuffer()
     53 

TypeError: can't pickle _thread.RLock objects
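
What the trace points to: Pool.starmap pickles every argument tuple before handing it to a worker, and a compiled Keras model holds _thread.RLock objects that the pickler cannot serialize. The failure can be reproduced without any pool at all (a minimal sketch, assuming the trained model from the first code block is in scope):

import pickle

# Pool.starmap serializes (model, test_df, seq_id, frm_id) with pickle;
# the model's internal thread locks make this fail immediately:
pickle.dumps(model)  # TypeError: can't pickle _thread.RLock objects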

After reading everything I could find on this topic on SO (e.g. here, here, and here) and on GitHub (e.g. here, here, and here), I learned that there are known issues with pickling sequential Keras models and with multithreading. So I tried loading the model inside the prediction function and deleting it after each call, like this:

def make_prediction(data, seq_id, frm_id):
    
    print(seq_id)
    
    from keras.models import load_model
    model = load_model('../input/repro-model/model')
    
    pred = model.predict_classes(data[(data.seq_id == seq_id) & (data.frm_id == frm_id)].drop(['seq_id', 'frm_id'], axis=1))
    
    del model
    
    return pred

from multiprocessing import Pool
import itertools

workers = 4
p = Pool(processes=workers)

seq_list, frm_list = np.arange(1, 201), np.arange(1, 6)
id_pair_list = list(itertools.product(seq_list, frm_list))

predictions = p.starmap(make_prediction, [(test_df, id_pair[0], id_pair[1]) for id_pair in id_pair_list])
p.close()

I no longer get an error, but the starmap call hangs after making the first four predictions, so something is definitely still wrong. Using model._make_predict_function(), as suggested here, did not work either. There are also several related SO questions on this topic that never got an answer: deprecated, here, here.
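
For reference, here is a pattern that is commonly suggested for this kind of hang (a sketch only, under the assumption that test_df from the code above is available and that each worker can hold its own copy of the model; not a verified fix for this exact Kaggle setup): use the 'spawn' start method so workers don't inherit the parent's TensorFlow state via fork, and load the model once per worker through a Pool initializer instead of once per task:

import itertools
import multiprocessing as mp

import numpy as np

worker_model = None  # set once per worker process by init_worker

def init_worker(model_path):
    # Importing and loading inside the worker gives each process its
    # own independent TensorFlow runtime and its own model copy.
    global worker_model
    from keras.models import load_model
    worker_model = load_model(model_path)

def make_prediction(data, seq_id, frm_id):
    rows = data[(data.seq_id == seq_id) & (data.frm_id == frm_id)]
    return worker_model.predict_classes(rows.drop(['seq_id', 'frm_id'], axis=1))

if __name__ == '__main__':
    ctx = mp.get_context('spawn')
    seq_list, frm_list = np.arange(1, 201), np.arange(1, 6)
    id_pair_list = list(itertools.product(seq_list, frm_list))

    with ctx.Pool(processes=4, initializer=init_worker,
                  initargs=('../input/repro-model/model',)) as p:
        predictions = p.starmap(make_prediction,
                                [(test_df, s, f) for s, f in id_pair_list])

Note that with 'spawn' the worker functions must be importable by the child process, so in a notebook they would have to live in a separate .py module; in a plain script the __main__ guard above is enough.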

Does anyone know how I can make predictions with a sequential Keras model in parallel across CPUs? That would be great, since I've been stuck on this problem for quite a while. Thanks a lot!

1 answer:

Answer 0 (score: 0)

Adding this snippet helped me solve the issue (though it only works for tf version 2.0.0):
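
The snippet itself is not reproduced above. A configuration block that circulated for TF 2.0.0-era multiprocessing problems looked like the following; this is purely a hypothetical reconstruction, not the answerer's confirmed code:

import tensorflow as tf

# Hypothetical reconstruction -- the answer's actual snippet was not preserved.
# Pin TensorFlow's thread pools through the v1 compatibility session before
# any model is loaded:
config = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=1,
                                  inter_op_parallelism_threads=1)
session = tf.compat.v1.Session(config=config)
tf.compat.v1.keras.backend.set_session(session)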
