为什么此数据集实现会用尽内存?

时间:2019-07-05 03:32:09

标签: python tensorflow tensorflow-datasets

我遵循this instruction并编写以下代码来为图像创建数据集(COCO2014 training set

from pathlib import Path
import tensorflow as tf


def image_dataset(filepath, image_size, batch_size, norm=True):
    def preprocess_image(image):
        image = tf.image.decode_jpeg(image, channels=3)
        image = tf.image.resize(image, image_size)
        if norm:
            image /= 255.0  # normalize to [0,1] range
        return image

    def load_and_preprocess_image(path):
        image = tf.read_file(path)
        return preprocess_image(image)

    all_image_paths = [str(f) for f in Path(filepath).glob('*')]
    path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
    ds = path_ds.map(load_and_preprocess_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.shuffle(buffer_size = len(all_image_paths))
    ds = ds.repeat()
    ds = ds.batch(batch_size)
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

    return ds

ds = image_dataset(train2014_dir, (256, 256), 4, False)
image = ds.make_one_shot_iterator().get_next('images')
# image is then fed to the network

此代码将始终耗尽内存(32G)和GPU(11G)并终止进程。这是终端上显示的消息。 enter image description here

我还发现该程序卡在sess.run(opt_op)上。哪里错了?我该如何解决?

1 个答案:

答案 0 :(得分:2)

问题是这样的:

ds = ds.shuffle(buffer_size = len(all_image_paths))

Dataset.shuffle()使用的缓冲区是“内存中”缓冲区,因此您正在有效地尝试将整个数据集加载到内存中。

您有几个选项(可以组合使用)来解决此问题:

选项1:

将缓冲区大小减小到更小的数字。

选项2:

shuffle()语句移至map()语句之前。

这意味着我们将在加载图像之前之前进行改组,因此我们只是将文件名存储在改组的内存缓冲区中,而不是存储庞大的张量。