我有一个大型数据集,包含约200万行和6,000列。输入的numpy数组(X,y)可以保存训练数据。但是当涉及到model.fit()时,出现了GPU内存不足错误。我正在使用tensorflow 2.2。根据其手册,不推荐使用model.fit_generator,而更推荐使用model.fit。
有人可以概述使用tensorflow v2.2训练大型数据集的步骤吗?
答案 0 :(得分:1)
最好的解决方案是使用tf.data.Dataset()
,因此您可以使用.batch()
方法轻松批处理数据。
这里有很多教程,您可能想使用from_tensor_slices()
直接玩numpy
数组。
下面有两个非常适合您需要的文档。