使用Estimators API结合tf.data.Dataset时如何加快批量准备

时间:2018-01-02 21:22:32

标签: tensorflow tensorflow-datasets tensorflow-estimator

我想加快使用Estimator API和使用tf.data.Dataset编写的input_fn的训练例程。

我的实现需要2秒钟来准备一批数据,然后在GPU上运行训练1秒,然后重新开始准备批处理。这真的很低效。

我正在寻找一种方法来异步准备批次并将其上传到GPU以加速培训。或者,对于在input_fn的调用之间缓存数据集的方法(dataset.cache()似乎不是一个好的选择,因为必须在每个input_fn调用上重新创建数据集。)

以下是我的代码的简化版本:

def input_fn(filenames, labels, epochs):
  dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
  dataset = dataset.map(_read_wav, num_parallel_calls=num_map_threads)
  if shuffle:
     dataset = dataset.shuffle(buffer_size=len(labels))
  dataset = dataset.map(_post_process,  num_parallel_calls=num_map_threads)
  dataset = dataset.map(lambda wav, label: ({'wav': wav}, label))
  dataset = dataset.batch(128)
  dataset = dataset.repeat(epochs) # to iterate over the training set forever
  iterator = dataset.dataset.make_one_shot_iterator()
  features, labels = iterator.get_next()
  return features, labels

train_input_fn = lambda : input_fn(train_files, train_labels, None)
eval_input_fn = lambda : input_fn(eval_files, eval_labels, 1)

train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=45000)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn) 
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

我注意到Estimator API正在积极开发中,在tensorflow的主分支中,input_fn已经可以返回数据集,所以也许我要求得太早,而且这个功能还没有准备好。但如果是这样,请提供可以跟踪此实施的票据。

2 个答案:

答案 0 :(得分:6)

使用tf.data.Dataset.cache()确实不是一个好选择,因为它会将整个数据集缓存到内存中,这需要时间并且可能会溢出内存。

要做的是在管道的末尾使用tf.data.Dataset.prefetch(),这将始终确保数据管道包含buffer_size个元素。通常在最后有buffer_size = 1就足够了:

dataset = ...
dataset = dataset.batch(128)
dataset = dataset.prefetch(1)  # prefetch one batch

正如this answer中@mrry所解释的那样,您也可以尝试增加预取批次的数量。

  

通常,在管道的最末端添加一个小的预取缓冲区(可能只有一个元素)是最有用的,但更复杂的管道可以从额外的预取中受益,特别是当生成单个元素的时间可以而变化。

如果与GPU计算相比,输入管道仍然较慢,则需要使用tf.data.Dataset.map()num_parallel_calls参数增加并行工作的线程数。

答案 1 :(得分:1)

要添加Olivier的答案,主要来自this post

    repeat之前的
  • shuffle稍微快一些,在模糊的时代边界的下方。这在极少数情况下可能很重要,但我对此表示怀疑。
  • shuffle ping之前
  • map - 这会减少shuffle缓冲区大小的内存占用量,因为它只需要缓冲文件名而不是文件内容。
  • 将第三个地图变换应用于get_next()的输出而不是数据集更有意义 - 不确定这是否会影响速度。您还可以考虑将其他两个地图调用放在同一个地图中,以减少调度问题。
  • repeat之前使用batch进行实验。可能不会有所作为,但可能是次要的。如果您repeat之前shuffle如上所述,则必须这样做。
  • 如Olivier所述,使用prefetch

包含修改的代码:

def input_fn(filenames, labels, epochs):
  dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
  dataset = dataset.repeat(epochs)
  if shuffle:
    dataset = dataset.shuffle(buffer_size=len(labels))

  def combined_map_fn(*args):
    return _post_process(_read_wav(*args))

  dataset = dataset.map(combined_map_fn, num_parallel_calls=num_map_threads)
  dataset = dataset.batch(128)
  dataset = dataset.prefetch(1)

  iterator = dataset.dataset.make_one_shot_iterator()
  wavs, labels = iterator.get_next()
  features = {'wav': wavs}
  return features, labels