我对Tensorflow的新输入管道机制有疑问。当我使用tf.data.Dataset创建数据管道时,它会解码jpeg图像然后将它们加载到队列中,它会尝试将尽可能多的图像加载到队列中。如果加载图像的吞吐量大于我的模型处理的图像的吞吐量,那么内存使用量会无限增加。
以下是使用tf.data.Dataset
构建管道的代码段def _imread(file_name, label):
_raw = tf.read_file(file_name)
_decoded = tf.image.decode_jpeg(_raw, channels=hps.im_ch)
_resized = tf.image.resize_images(_decoded, [hps.im_width, hps.im_height])
_scaled = (_resized / 127.5) - 1.0
return _scaled, label
n_samples = image_files.shape.as_list()[0]
dset = tf.data.Dataset.from_tensor_slices((image_files, labels))
dset = dset.shuffle(n_samples, None)
dset = dset.repeat(hps.n_epochs)
dset = dset.map(_imread, hps.batch_size * 32)
dset = dset.batch(hps.batch_size)
dset = dset.prefetch(hps.batch_size * 2)
此处image_files
是一个常数张量,包含30k图像的文件名。图像在_imread中调整为256x256x3。
如果使用以下代码段构建管道:
# refer to "https://www.tensorflow.org/programmers_guide/datasets"
def _imread(file_name, hps):
_raw = tf.read_file(file_name)
_decoded = tf.image.decode_jpeg(_raw, channels=hps.im_ch)
_resized = tf.image.resize_images(_decoded, [hps.im_width, hps.im_height])
_scaled = (_resized / 127.5) - 1.0
return _scaled
n_samples = image_files.shape.as_list()[0]
image_file, label = tf.train.slice_input_producer(
[image_files, labels],
num_epochs=hps.n_epochs,
shuffle=True,
seed=None,
capacity=n_samples,
)
# Decode image.
image = _imread(image_file,
images, labels = tf.train.shuffle_batch(
tensors=[image, label],
batch_size=hps.batch_size,
capacity=hps.batch_size * 64,
min_after_dequeue=hps.batch_size * 8,
num_threads=32,
seed=None,
enqueue_many=False,
allow_smaller_final_batch=True
)
然后整个训练过程中内存使用率几乎不变。如何使tf.data.Dataset加载固定数量的样本?我用tf.data.Dataset创建的管道是否正确?我认为tf.data.Dataset.shuffle中的buffer_size
参数适用于image_files
和labels
。所以存储30k字符串应该不是问题,对吧?即使要加载30k图像,也需要30000*256*256*3*8/(1024*1024*1024)=
43GB的内存。但它使用59GB的61GB系统内存。
答案 0 :(得分:1)
这将缓冲n_samples,它们看起来就像是你的整个数据集。你可能想在这里减少缓冲。
dset = dset.shuffle(n_samples, None)
你也可以永远重复,重复赢得缓冲(Does `tf.data.Dataset.repeat()` buffer the entire dataset in memory?)
dset = dset.repeat()
您正在批处理,然后预取hps.batch_size批次。哎哟!
dset = dset.batch(hps.batch_size)
dset = dset.prefetch(hps.batch_size * 2)
让hps.batch_size = 1000说一个具体的例子。上面的第一行创建了一批1000张图像。上面的第2行创建了2000个批次,每1000个图像,缓冲总计2,000,000个图像。糟糕!
你的意思是:
dset = dset.batch(hps.batch_size)
dset = dset.prefetch(2)