如何在tensorflow2.1.0中加载大型numpy数组数据集以训练CNN模型

时间:2020-05-17 22:57:34

标签: numpy tensorflow deep-learning

我正在训练卷积神经网络(CNN)模型以在Tensorflow2.1.0中进行二进制分类任务。 每个实例的特征是形状为(50,50,50,2)的4维numpy数组,其中每个元素的类型为float32。 每个实例的标签为1或0 我最大的训练数据集可以包含多达1亿个实例。

要有效地训练模型,最好将我的训练数据序列化并以TFrecord格式存储在一组文件中,然后使用tf.data.TFRecordDataset()加载它们并使用tf.data.map对其进行解析()? 如果是这样,您能给我一个例子,说明如何序列化功能标签对并将它们存储到TFrecord文件中,然后如何加载和解析它们? 我在Tensorflow网站上找不到合适的示例。

还是有更好的方法来存储和加载庞大的数据集?非常感谢。

1 个答案:

答案 0 :(得分:0)

有很多方法可以在没有TFRecord的情况下有效地建立数据管道,请点击此link it was very useful

要有效地从目录中提取图像,然后单击此link

希望这对您有所帮助。