我有一个可以容纳主机内存的大型数据集。但是,当我使用tf.keras训练模型时,会产生GPU内存不足的问题。然后,我研究tf.data.Dataset并想使用其batch()方法批处理训练数据集,以便它可以在GPU中执行model.fit()。根据其文档,示例如下:
height constraint
dataset.from_tensor_slices()。batch()中的BATCH_SIZE是否与tf.keras modelt.fit()中的batch_size相同?
我应该如何选择BATCH_SIZE,以便GPU具有足够的数据来有效运行,而其内存不会溢出?
答案 0 :(得分:2)
在这种情况下,您不需要在batch_size
中传递model.fit()
参数。它将自动使用您在tf.data.Dataset().batch()
中使用的BATCH_SIZE。
关于您的其他问题:确实需要仔细调整批大小超参数。另一方面,如果看到OOM错误,则应减少该错误,直到没有OOM为止(通常以这种方式32-> 16-> 8 ...)。
在您的情况下,我将从2的batch_size开始,将其增加2的幂,然后检查是否仍然获得OOM。
如果使用batch_size
方法,则无需提供tf.data.Dataset().batch()
参数。
事实上,甚至官员documentation都这样指出:
batch_size:整数或无。每个梯度更新的样本数。 如果未指定,batch_size将默认为32。请勿指定 batch_size(如果您的数据是数据集,生成器或 keras.utils.Sequence实例(因为它们生成批次)。