如何在新的编码器 - 解码器模型中保存tensorflow dynamic_rnn模型并将其恢复为解码器?

时间:2017-06-06 04:26:08

标签: tensorflow save restore

我正在尝试训练编码器 - 解码器模型以自动生成摘要。编码器部分使用CNN编码文章的摘要。解码器部分是RNN以生成文章的标题。

所以骨架看起来像:

encoder_state = CNNEncoder(encoder_inputs)
decoder_outputs, _ = RNNDecoder(encoder_state,decoder_inputs)

但我想预先训练RNN解码器,教导模型先学会说话。解码器部分是:

def RNNDecoder(encoder_state,decoder_inputs):
    decoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, decoder_inputs)
    #from tensorflow.models.rnn import rnn_cell, seq2seq
    cell = rnn.GRUCell(memory_dim)
    decoder_outputs, decoder_final_state = tf.nn.dynamic_rnn(
        cell, decoder_inputs_embedded,
        initial_state=encoder_state,
        dtype=tf.float32,scope="plain_decoder1"
    )
    return decoder_outputs, decoder_final_state

所以我关心的是如何分别保存和恢复RNNDecoder部分?

1 个答案:

答案 0 :(得分:1)

在这里,您可以先获取动态RNN的输出。

decoder_cell = tf.contrib.rnn.LSTMCell(decoder_hidden_units)
decoder_outputs, decoder_final_state = tf.nn.dynamic_rnn(decoder_cell, decoder_inputs_embedded,initial_state=encoder_final_state,dtype=tf.float32, time_major=True, scope="plain_decoder")

选择 decoder_outputs 。然后使用softmax图层完全连接它。

decoder_logits = tf.contrib.layers.linear(decoder_outputs, vocab_`size)

然后你可以用 decoder_logits 创建一个softmax损失并以正常的方式训练它。

如果要在会话中恢复参数

,请使用此类方法
with tf.Session() as session:
        saver = tf.train.Saver()
        saver.restore(session, checkpoint_file)

此处检查点文件应该是您的确切检查点文件。因此,当运行所发生的事情时,它只会恢复您的解码器权重并使用主模型进行训练。