如何在Tensorflow的PTB LSTM示例中迭代批处理?

时间:2016-12-06 09:32:16

标签: tensorflow

我目前正在尝试了解Tensorflow的LSTM教程,并对https://github.com/tensorflow/tensorflow/blob/master/tensorflow/models/rnn/ptb/ptb_word_lm.py的代码提出疑问。 在函数run_epoch()中,这些行运行一个纪元,input.epoch_size实际上是数据的批次数:

for step in range(model.input.epoch_size):
    feed_dict = {}
    for i, (c, h) in enumerate(model.initial_state):
      feed_dict[c] = state[i].c
      feed_dict[h] = state[i].h

    vals = session.run(fetches, feed_dict)
    cost = vals["cost"]
    state = vals["final_state"]

    costs += cost
    iters += model.input.num_steps

    if verbose and step % (model.input.epoch_size // 10) == 10:
      print("%.3f perplexity: %.3f speed: %.0f wps" %
            (step * 1.0 / model.input.epoch_size, np.exp(costs / iters),
             iters * model.input.batch_size / (time.time() - start_time)))

但是我想知道,这段代码怎么说"告诉"我们的时代的LSTM模型?在init方法中的LSTM类中,加载了整个数据,并且通常在数据上定义计算。

我的第二个问题是ch的计算。我们为什么这样做?它与有状态与无状态LSTM有关吗?那么我可以安全地为香草LSTM删除该代码吗?

谢谢!

2 个答案:

答案 0 :(得分:4)

如果您在同一个文件中看到line 348,则代码会为每个纪元调用一次run_epoch()。每个时期,LSTM单元在训练进行时被初始化为全零状态。提出你的问题,

  

但我想知道,这段代码如何“告诉”LSTM模型我们在哪个时代呢?在init方法中的LSTM类中,加载了整个数据,并且通常在数据上定义计算。

正在更新LSTM单元内的权重,并且在每个时期的开始使用LSTM的initial_state。没有必要明确地告诉LSTM时代号。

  

我的第二个问题是c和h的计算。我们为什么这样做?它与有状态与无状态LSTM有关吗?那么我可以安全地为香草LSTM删除该代码吗?

这是非常重要的一步。这样做是为了在不同批次中传递LSTM状态。 LSTM有两个内部状态ch。当这些被送入图表时,前一批次的最终状态成为下一批次的初始状态。您可以通过计算model.final_state并在下次feed_dict中传递它来替换它。如果您查看TensorFlow代码,statech时基本上是state_is_tupleTrue的元组,因为您可以阅读here }。

答案 1 :(得分:0)

# for i, (c, h) in enumerate(model.initial_state):
#   feed_dict[c] = state[i].c
#   feed_dict[h] = state[i].h

feed_dict[model._initial_state]=state;

For循环是初始化当前批次的第一个单元格状态,具有最后一批的最终单元格状态。