我目前正在尝试了解Tensorflow的LSTM教程,并对https://github.com/tensorflow/tensorflow/blob/master/tensorflow/models/rnn/ptb/ptb_word_lm.py的代码提出疑问。
在函数run_epoch()
中,这些行运行一个纪元,input.epoch_size
实际上是数据的批次数:
for step in range(model.input.epoch_size):
feed_dict = {}
for i, (c, h) in enumerate(model.initial_state):
feed_dict[c] = state[i].c
feed_dict[h] = state[i].h
vals = session.run(fetches, feed_dict)
cost = vals["cost"]
state = vals["final_state"]
costs += cost
iters += model.input.num_steps
if verbose and step % (model.input.epoch_size // 10) == 10:
print("%.3f perplexity: %.3f speed: %.0f wps" %
(step * 1.0 / model.input.epoch_size, np.exp(costs / iters),
iters * model.input.batch_size / (time.time() - start_time)))
但是我想知道,这段代码怎么说"告诉"我们的时代的LSTM模型?在init方法中的LSTM类中,加载了整个数据,并且通常在数据上定义计算。
我的第二个问题是c
和h
的计算。我们为什么这样做?它与有状态与无状态LSTM有关吗?那么我可以安全地为香草LSTM删除该代码吗?
谢谢!
答案 0 :(得分:4)
如果您在同一个文件中看到line 348,则代码会为每个纪元调用一次run_epoch()
。每个时期,LSTM单元在训练进行时被初始化为全零状态。提出你的问题,
但我想知道,这段代码如何“告诉”LSTM模型我们在哪个时代呢?在init方法中的LSTM类中,加载了整个数据,并且通常在数据上定义计算。
正在更新LSTM单元内的权重,并且在每个时期的开始使用LSTM的initial_state
。没有必要明确地告诉LSTM时代号。
我的第二个问题是c和h的计算。我们为什么这样做?它与有状态与无状态LSTM有关吗?那么我可以安全地为香草LSTM删除该代码吗?
这是非常重要的一步。这样做是为了在不同批次中传递LSTM状态。 LSTM有两个内部状态c
和h
。当这些被送入图表时,前一批次的最终状态成为下一批次的初始状态。您可以通过计算model.final_state
并在下次feed_dict
中传递它来替换它。如果您查看TensorFlow代码,state
在c
为h
时基本上是state_is_tuple
和True
的元组,因为您可以阅读here }。
答案 1 :(得分:0)
# for i, (c, h) in enumerate(model.initial_state):
# feed_dict[c] = state[i].c
# feed_dict[h] = state[i].h
feed_dict[model._initial_state]=state;
For循环是初始化当前批次的第一个单元格状态,具有最后一批的最终单元格状态。