我正在尝试训练编码器 - 解码器模型以自动生成摘要。编码器部分使用CNN编码文章的摘要。解码器部分是RNN以生成文章的标题。
所以骨架看起来像:
encoder_state = CNNEncoder(encoder_inputs)
decoder_outputs, _ = RNNDecoder(encoder_state,decoder_inputs)
但我想预先训练RNN解码器,教导模型先学会说话。解码器部分是:
def RNNDecoder(encoder_state,decoder_inputs):
decoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, decoder_inputs)
#from tensorflow.models.rnn import rnn_cell, seq2seq
cell = rnn.GRUCell(memory_dim)
decoder_outputs, decoder_final_state = tf.nn.dynamic_rnn(
cell, decoder_inputs_embedded,
initial_state=encoder_state,
dtype=tf.float32,scope="plain_decoder1"
)
return decoder_outputs, decoder_final_state
所以我关心的是如何分别保存和恢复RNNDecoder部分?
答案 0 :(得分:1)
在这里,您可以先获取动态RNN的输出。
decoder_cell = tf.contrib.rnn.LSTMCell(decoder_hidden_units)
decoder_outputs, decoder_final_state = tf.nn.dynamic_rnn(decoder_cell, decoder_inputs_embedded,initial_state=encoder_final_state,dtype=tf.float32, time_major=True, scope="plain_decoder")
选择 decoder_outputs 。然后使用softmax图层完全连接它。
decoder_logits = tf.contrib.layers.linear(decoder_outputs, vocab_`size)
然后你可以用 decoder_logits 创建一个softmax损失并以正常的方式训练它。
如果要在会话中恢复参数
,请使用此类方法with tf.Session() as session:
saver = tf.train.Saver()
saver.restore(session, checkpoint_file)
此处检查点文件应该是您的确切检查点文件。因此,当运行所发生的事情时,它只会恢复您的解码器权重并使用主模型进行训练。