Tensorflow seq2seq重量分配

时间:2015-12-21 14:45:35

标签: sequence tensorflow

def rnn_seq2seq(encoder_inputs, decoder_inputs, cell, output_projection=None,feed_previous=False, dtype=tf.float32, scope=None):
  with tf.variable_scope(scope or "rnn_seq2seq"):
    _, enc_states = rnn.rnn(cell, encoder_inputs,dtype=dtype)


  def extract_argmax(prev, i):
    if output_projection is not None:
        prev = tf.nn.xw_plus_b(prev, output_projection[0], output_projection[1])
    return tf.to_float(tf.equal(prev,tf.reduce_max(prev,reduction_indices=[1],keep_dims=True)))

  loop_function = None
  if feed_previous:
    loop_function = extract_argmax

       #seq2seq.rnn_decoder is provided in tensorflow/models/rnn/seq2seq.py
  return seq2seq.rnn_decoder(decoder_inputs, enc_states[-1], cell, loop_function=loop_function)

我想创建两个RNN模型,一个用于培训,另一个用于测试。为此,我可以将函数调用两次,将feed_previous传递给True或False。

train_op,train_states = rnn_seq2seq(enc_inp,dec_inp,cell,output_projection=op,feed_previous=False)
test_op,_ = rnn_seq2seq(enc_inp,dec_inp,cell,output_projection=op,feed_previous=True)

但如果我将上述功能调用两次,是否会创建两个不同的RNN?我想知道他们是否能分享权重。

1 个答案:

答案 0 :(得分:1)

两个函数都在相同的默认图形上运行,因此可以重复使用变量,查看variable scopes tutorial并查看您的变量是否使用reuse=True参数创建

作为完整性检查,请尝试使用以下代码段列出默认图表中的所有变量:

[v.name for v in tf.get_default_graph().as_graph_def().node if v.op=='Variable']