tf.data.Dataset

时间:2017-11-12 16:41:56

标签: tensorflow

我对Tensorflow的新输入管道机制有疑问。当我使用tf.data.Dataset创建数据管道时,它会解码jpeg图像然后将它们加载到队列中,它会尝试将尽可能多的图像加载到队列中。如果加载图像的吞吐量大于我的模型处理的图像的吞吐量,那么内存使用量会无限增加。

以下是使用tf.data.Dataset

构建管道的代码段
def _imread(file_name, label):
  _raw = tf.read_file(file_name)
  _decoded = tf.image.decode_jpeg(_raw, channels=hps.im_ch)
  _resized = tf.image.resize_images(_decoded, [hps.im_width, hps.im_height])
  _scaled = (_resized / 127.5) - 1.0
  return _scaled, label

n_samples = image_files.shape.as_list()[0]
dset = tf.data.Dataset.from_tensor_slices((image_files, labels))
dset = dset.shuffle(n_samples, None)
dset = dset.repeat(hps.n_epochs)
dset = dset.map(_imread, hps.batch_size * 32)
dset = dset.batch(hps.batch_size)
dset = dset.prefetch(hps.batch_size * 2)

此处image_files是一个常数张量,包含30k图像的文件名。图像在_imread中调整为256x256x3。

如果使用以下代码段构建管道:

# refer to "https://www.tensorflow.org/programmers_guide/datasets"
def _imread(file_name, hps):
  _raw = tf.read_file(file_name)
  _decoded = tf.image.decode_jpeg(_raw, channels=hps.im_ch)
  _resized = tf.image.resize_images(_decoded, [hps.im_width, hps.im_height])
  _scaled = (_resized / 127.5) - 1.0
  return _scaled

n_samples = image_files.shape.as_list()[0]

image_file, label = tf.train.slice_input_producer(
  [image_files, labels],
  num_epochs=hps.n_epochs,
  shuffle=True,
  seed=None,
  capacity=n_samples,
)

# Decode image.
image = _imread(image_file, 

images, labels = tf.train.shuffle_batch(
  tensors=[image, label],
  batch_size=hps.batch_size,
  capacity=hps.batch_size * 64,
  min_after_dequeue=hps.batch_size * 8,
  num_threads=32,
  seed=None,
  enqueue_many=False,
  allow_smaller_final_batch=True
)

然后整个训练过程中内存使用率几乎不变。如何使tf.data.Dataset加载固定数量的样本?我用tf.data.Dataset创建的管道是否正确?我认为tf.data.Dataset.shuffle中的buffer_size参数适用于image_fileslabels。所以存储30k字符串应该不是问题,对吧?即使要加载30k图像,也需要30000*256*256*3*8/(1024*1024*1024)= 43GB的内存。但它使用59GB的61GB系统内存。

1 个答案:

答案 0 :(得分:1)

这将缓冲n_samples,它们看起来就像是你的整个数据集。你可能想在这里减少缓冲。

dset = dset.shuffle(n_samples, None)

你也可以永远重复,重复赢得缓冲(Does `tf.data.Dataset.repeat()` buffer the entire dataset in memory?

dset = dset.repeat()

您正在批处理,然后预取hps.batch_size批次。哎哟!

dset = dset.batch(hps.batch_size)
dset = dset.prefetch(hps.batch_size * 2)

让hps.batch_size = 1000说一个具体的例子。上面的第一行创建了一批1000张图像。上面的第2行创建了2000个批次,每1000个图像,缓冲总计2,000,000个图像。糟糕!

你的意思是:

dset = dset.batch(hps.batch_size)
dset = dset.prefetch(2)