我想加快使用Estimator API和使用tf.data.Dataset
编写的input_fn的训练例程。
我的实现需要2秒钟来准备一批数据,然后在GPU上运行训练1秒,然后重新开始准备批处理。这真的很低效。
我正在寻找一种方法来异步准备批次并将其上传到GPU以加速培训。或者,对于在input_fn
的调用之间缓存数据集的方法(dataset.cache()
似乎不是一个好的选择,因为必须在每个input_fn调用上重新创建数据集。)
以下是我的代码的简化版本:
def input_fn(filenames, labels, epochs):
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_read_wav, num_parallel_calls=num_map_threads)
if shuffle:
dataset = dataset.shuffle(buffer_size=len(labels))
dataset = dataset.map(_post_process, num_parallel_calls=num_map_threads)
dataset = dataset.map(lambda wav, label: ({'wav': wav}, label))
dataset = dataset.batch(128)
dataset = dataset.repeat(epochs) # to iterate over the training set forever
iterator = dataset.dataset.make_one_shot_iterator()
features, labels = iterator.get_next()
return features, labels
train_input_fn = lambda : input_fn(train_files, train_labels, None)
eval_input_fn = lambda : input_fn(eval_files, eval_labels, 1)
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=45000)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
我注意到Estimator API正在积极开发中,在tensorflow的主分支中,input_fn已经可以返回数据集,所以也许我要求得太早,而且这个功能还没有准备好。但如果是这样,请提供可以跟踪此实施的票据。
答案 0 :(得分:6)
使用tf.data.Dataset.cache()
确实不是一个好选择,因为它会将整个数据集缓存到内存中,这需要时间并且可能会溢出内存。
要做的是在管道的末尾使用tf.data.Dataset.prefetch()
,这将始终确保数据管道包含buffer_size
个元素。通常在最后有buffer_size = 1
就足够了:
dataset = ...
dataset = dataset.batch(128)
dataset = dataset.prefetch(1) # prefetch one batch
正如this answer中@mrry所解释的那样,您也可以尝试增加预取批次的数量。
通常,在管道的最末端添加一个小的预取缓冲区(可能只有一个元素)是最有用的,但更复杂的管道可以从额外的预取中受益,特别是当生成单个元素的时间可以而变化。
如果与GPU计算相比,输入管道仍然较慢,则需要使用tf.data.Dataset.map()
的num_parallel_calls
参数增加并行工作的线程数。
答案 1 :(得分:1)
要添加Olivier的答案,主要来自this post:
repeat
之前的shuffle
稍微快一些,在模糊的时代边界的下方。这在极少数情况下可能很重要,但我对此表示怀疑。shuffle
ping之前map
- 这会减少shuffle缓冲区大小的内存占用量,因为它只需要缓冲文件名而不是文件内容。get_next()
的输出而不是数据集更有意义 - 不确定这是否会影响速度。您还可以考虑将其他两个地图调用放在同一个地图中,以减少调度问题。repeat
之前使用batch
进行实验。可能不会有所作为,但可能是次要的。如果您repeat
之前shuffle
如上所述,则必须这样做。prefetch
。包含修改的代码:
def input_fn(filenames, labels, epochs):
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.repeat(epochs)
if shuffle:
dataset = dataset.shuffle(buffer_size=len(labels))
def combined_map_fn(*args):
return _post_process(_read_wav(*args))
dataset = dataset.map(combined_map_fn, num_parallel_calls=num_map_threads)
dataset = dataset.batch(128)
dataset = dataset.prefetch(1)
iterator = dataset.dataset.make_one_shot_iterator()
wavs, labels = iterator.get_next()
features = {'wav': wavs}
return features, labels