tf model.fit()中的batch_size与tf.data.Dataset中的batch_size

时间:2020-07-01 05:04:38

标签: tensorflow tensorflow2.0 tensorflow2.x

我有一个可以容纳主机内存的大型数据集。但是,当我使用tf.keras训练模型时,会产生GPU内存不足的问题。然后,我研究tf.data.Dataset并想使用其batch()方法批处理训练数据集,以便它可以在GPU中执行model.fit()。根据其文档,示例如下:

height constraint

dataset.from_tensor_slices()。batch()中的BATCH_SIZE是否与tf.keras modelt.fit()中的batch_size相同?

我应该如何选择BATCH_SIZE,以便GPU具有足够的数据来有效运行,而其内存不会溢出?

1 个答案:

答案 0 :(得分:2)

在这种情况下,您不需要在batch_size中传递model.fit()参数。它将自动使用您在tf.data.Dataset().batch()中使用的BATCH_SIZE。

关于您的其他问题:确实需要仔细调整批大小超参数。另一方面,如果看到OOM错误,则应减少该错误,直到没有OOM为止(通常以这种方式32-> 16-> 8 ...)。

在您的情况下,我将从2的batch_size开始,将其增加2的幂,然后检查是否仍然获得OOM。

如果使用batch_size方法,则无需提供tf.data.Dataset().batch()参数。

事实上,甚至官员documentation都这样指出:

batch_size:整数或无。每个梯度更新的样本数。 如果未指定,batch_size将默认为32。请勿指定 batch_size(如果您的数据是数据集,生成器或 keras.utils.Sequence实例(因为它们生成批次)。