为MNIST数据集运行自定义Tensorflow训练循环时出现OOM错误

时间:2019-10-15 15:49:06

标签: python tensorflow

我编写了一个定制的Tensorflow训练循环来训练MNIST分类器。 我遇到了错误:

OOM when allocating tensor for MNIST

The snap shot of error

这是我的代码:https://github.com/soon22/learningTensorflowCustomTrainingLoop/blob/master/mnist_custom_training_loop.ipynb

使用tensorflow.keras的{​​{1}}和model.compile,训练成功了,准确率超过90%。它没有这个问题。 我做错了什么。

1 个答案:

答案 0 :(得分:0)

您正在尝试让模型同时针对MNIST数据集中的所有60000个数据点进行预测,并为所得损失计算梯度。对于您的图形卡来说,这太多了。

尝试训练说几百个数据点。 model.fit没有给出OOM的原因是,如果您未为batch_size指定另一个值(请参见https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit),则model.fit的默认批次大小为32。