神经网络实现中的内存不足(使用Numpy数组)

时间:2019-03-19 18:09:00

标签: python numpy tensorflow keras deep-learning

我的数据集的格式如下:

培训数据

大小为(7855、448、448、3)的numpy数组,其中(448、448、3)是RGB图像的numpy版本。因为网络的目的是回归,所以我还没有找到使用ImageDataGenerator的解决方案。因此,我将整个图像数据集转换为一个numpy数组。

培训目标

训练目标是大小为7855的一维numpy数组。条目对应于训练数据。

要获得numpy数组,我必须将整个数据集加载到内存中的变量中,然后将其传递以进行拟合和预测。仅此一项就需要多达5到6个RAM。

拟合模型时,RAM迅速溢出,并且运行时崩溃。如何分批填充numpy数组元素,或者是否有另一种加载格式为的数据集的方法:

|list of images |
|labelled       |
|1, 2, 3...     |
|n              |


|csv file with: |
|1   target1    |
|2   target2    |
|3   target3... |

CODE https://colab.research.google.com/drive/1FUvPcpYiDtli6vwIaTwacL48RwZ0sq-9

[我一直在使用Google Colab,因为这是一个学术研究项目,还没有投资高端服务器。 ]

1 个答案:

答案 0 :(得分:0)

您需要使用数据集API。 创建numpy数组,train_images和train_target时,请使用tf.data.Dataset.from_tensor_slices

dataset = tf.data.Dataset.from_tensor_slices((train_images, train_target))

这将创建数据集对象,可以将其输入model.fit 您可以将任何解析函数混洗,批处理并映射到该数据集。您可以控制将随机播放缓冲区预装入多少个示例。重复控件的纪元计数,最好保留None,这样它将无限期重复。<​​/ p>

dataset = dataset.shuffle().repeat()
dataset = dataset.batch()

请记住,批处理在此管道内进行,因此您不需要在model.fit中使用批处理,但是您需要传递历元数和每个历元的步骤。后者可能会有些棘手,因为您无法像len(dataset)那样做,因此应提前计算。

model.fit(dataset, epochs, steps_per_epoch)

如果您遇到graphdef限制错误,最好保存几个较小的numpy数组并将它们作为列表传递

使自己熟悉这个方法 https://www.tensorflow.org/guide/datasets 希望这会有所帮助。