为什么在每个时代开始时培训损失会增加?

时间:2017-10-25 09:08:27

标签: machine-learning tensorflow tensorflow-datasets

我正在训练线性回归模型。我使用tf.contrib.data来准备数据集,将其洗牌并分批投放:

  dataset = tf.contrib.data.TFRecordDataset(filename)
  dataset = dataset.map(
      _parse_function, num_threads=16, output_buffer_size=100 * batch_size)
  dataset = dataset.repeat(5)
  dataset = dataset.shuffle(buffer_size=100000)
  dataset = dataset.padded_batch(batch_size, padded_shapes=([None], [None]))
  iterator = dataset.make_initializable_iterator()
  x_inputs, y_ = iterator.get_next()

以下是我们的培训损失: training loss

很奇怪,在每个时期的开始(迭代= 100k),我们在训练损失中有一个脉冲。如果训练过程继续进行,我们会在下一个时期开始时看到相同的模式。

3 个答案:

答案 0 :(得分:2)

假设您的数据集少于100000条记录,则问题可能是输入随机播放中的随机性不足。直观地说,如果现有数据没有被洗牌并且它们的顺序有一些结构,那么训练过程可能会过度填充到文件末尾的记录,当你在开始时重新启动时,模型将无法在记录上执行。在文件开头附近,损失会增加。

正确的解决方案将取决于数据集的精确细节,但以下某些方法可能有用:

  1. 如果可能,请将传递给buffer_size的{​​{1}}增加到与数据集一样大的值。 (如果您的整个数据集不适合内存,则可能无法实现。)

  2. 通过在训练前随机化输入文件的顺序,确保数据在记录顺序中没有任何结构。

  3. 除了随机化文件中的顺序外,您还可以将数据分区为多个不同的文件,并使用以下内容随机化访问这些文件的顺序:

    Dataset.shuffle()
  4. 作为(3)的扩展,您可以使用dataset = tf.data.Dataset.list_files(shard_filenames) dataset = dataset.shuffle(len(shard_filenames)) dataset = dataset.flat_map(lambda filename: tf.data.TFRecordDataset(filename)) # ... 代替Dataset.interleave()一次读取多个分片。

答案 1 :(得分:0)

在这里可以看到相同的锯齿形:https://discuss.pytorch.org/t/strange-behavior-with-sgd-momentum-training/7442

建议的解决方案是确保您的数据加载器加载替换后的数据。

答案 2 :(得分:0)

谢谢。

当我不随机播放示例或输入文件时,我遇到了相同的模式:

enter image description here

按照您的建议,我分别整理了示例 tfrecord文件,但是当新纪元开始时,仍然有些奇怪的事情(请参见下图) )。

  • 示例数:80M
  • 批处理大小:256
  • 一个时期的批次数量:312,500
  • 初始学习率是10 ^ -3,在50万批之后,我将其更改为10 ^ -4
  • 让我知道您是否需要有关培训或数据的更多详细信息(我已尝试使其保持简单)。

火车(橙色)和测试(蓝色)损耗绘制在下图中。

我们可以看到,每个时期之后,火车损失都有所下降:第一个时期在312.5K批次后完成,第二个时期在625K批次后完成,第三个时期在937.5K之后批次。

我想这与我向TensorBoard报告的方式有关,但我很乐意在这里获得您的建议。橙色图中的每个点(火车损失)是最近5K批次损失的平均值。 TensorBoard平滑设置为0。

我用于TensorBoard报告的代码是:

summary = tf.Summary(value=[tf.Summary.Value(tag='avg loss', simple_value=avg_loss)])
summary_writer.add_summary(summary, current_global_step)

您以前遇到过这种模式吗?

enter image description here