大数据集上的predict_on_batch()中的内存错误

时间:2018-10-04 08:54:25

标签: python keras deep-learning

我有18000个示例的测试集。

Χ_test.shape: (18000, 128, 128, 1)

我已经训练好模型,并希望在X_test上使用预测。

如果我尝试仅使用:

pred = model.predict_on_batch(X_test)

出现内存错误。

我尝试过类似的事情:

X_test_split = X_test.flatten()
X_test_split = np.array_split(X_test_split, 562) # batch size is 32
pred = np.empty(len(X_test_split), dtype=np.float32)

for idx, _ in enumerate(X_test_split):
    pred[idx] = model.predict_on_batch(X_test_split[idx].reshape(32, 128, 128, 1))

但是它要么再次给我带来内存错误,要么给我关于重整的错误(取决于我在上面的代码中尝试的变化)

我也使用predict_generator遇到相同的问题。

1 个答案:

答案 0 :(得分:1)

根据OP的要求,我将发表我的评论作为答案,并尝试详细说明:

您的模型似乎很大,因此您需要使用较小的批处理大小(<32,因为您提到它不适用于32)或修改模型并减少参数数量(例如,删除一些图层) ,减少过滤器或单元的数量等。