如何在tensorflow 2.x上训练大型数据集

时间:2020-06-29 06:36:30

标签: tensorflow tensorflow2.x

我有一个大型数据集,包含约200万行和6,000列。输入的numpy数组(X,y)可以保存训练数据。但是当涉及到model.fit()时,出现了GPU内存不足错误。我正在使用tensorflow 2.2。根据其手册,不推荐使用model.fit_generator,而更推荐使用model.fit。

有人可以概述使用tensorflow v2.2训练大型数据集的步骤吗?

1 个答案:

答案 0 :(得分:1)

最好的解决方案是使用tf.data.Dataset(),因此您可以使用.batch()方法轻松批处理数据。

这里有很多教程,您可能想使用from_tensor_slices()直接玩numpy数组。

下面有两个非常适合您需要的文档。

https://www.tensorflow.org/tutorials/load_data/numpy

https://www.tensorflow.org/guide/data