tensorflow复制rnn中的变量

时间:2016-02-26 21:37:19

标签: tensorflow

与此相关:How can I copy a variable in tensorflow

我正在尝试复制lstm解码单元的值,以便在其他地方使用它来进行beamsearch。在伪代码中,我想要这样的东西:

lstm_decode = tf.nn.rnn_cell(...)
training_output = tf.nn.seq2seq.rnn_decoder(...)
... do training by back-prop the error on trainint_output ...

# duplicate the lstm_decode unit (same weights)
lstm_decode_copy = copy(lstm_decode)
... do beam search with the duplicated lstm ...

问题是在tensorflow中,在调用“tf.nn.rnn_cell(...)”期间不会生成lstm变量,但它实际上是在向rnn_decoder的函数调用展开期间生成的。

我可以将范围设置为“tf.nn.seq2seq.rnn_decoder”函数调用,但是lstm权重的实际初始化对我来说是不透明的。我如何捕获这些值并重新使用它们来创建一个与所学习的权重相同的lstm单元?

谢谢!

1 个答案:

答案 0 :(得分:0)

我想我已经明白了。

你想要的是在这一行中将解码器调用的范围设置为特定值,比如“解码”:

training_output = tf.nn.seq2seq.rnn_decoder(...scope="decoding")

稍后当您想要使用在解码期间学习的lstm单位时,将变量范围再次设置为“解码”,并使用scope.reuse_variables()来允许重复使用解码变量。然后只需使用“lstm_decode”,否则将绑定到与之前相同的值。

with tf.variable_scope("decoding") as scope:
  scope.reuse_variables()
  ... use lstm_decode as usual ...

这样,lstm_decode中的所有权重都将在这两个子图中共享,并且在训练期间学到的任何值也将出现在第二部分中。