TensorFlow: restore LSTM for NMT inference

Time: 2018-08-07 21:29:11

Tags: tensorflow lstm seq2seq

I have built a model for machine translation and want to save it and load it later for inference. Below is part of the code for the model.

with tf.variable_scope("Encoder"):
    encoder_lstm = tf.nn.rnn_cell.BasicLSTMCell(latent_dim, name = 'encoder_lstm')
    enc_outputs, enc_states = tf.nn.dynamic_rnn(encoder_lstm, encoder_emb_input, time_major = False, dtype = tf.float32)

with tf.variable_scope("Decoder_with_Attention"):
    attention_states = enc_outputs
    attention_mechanism = tf.contrib.seq2seq.LuongAttention(latent_dim, attention_states)

    decoder_lstm = tf.contrib.seq2seq.AttentionWrapper(tf.nn.rnn_cell.BasicLSTMCell(latent_dim, name = 'decoder_lstm'), attention_mechanism, attention_layer_size = latent_dim, name = 'wrapper')
    helper = tf.contrib.seq2seq.TrainingHelper(decoder_emb_input, decoder_seq_len, time_major = False)
    projection_layer = Dense(spa_features, use_bias = False, name = 'projection_layer')
    decoder = tf.contrib.seq2seq.BasicDecoder(decoder_lstm, helper, decoder_lstm.zero_state(dtype = tf.float32, batch_size = batch_size).clone(cell_state = enc_states), output_layer = projection_layer)
    outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder)
    outputs = outputs.rnn_output

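The saving step itself is not shown in the question; below is a minimal sketch of how it is typically done with tf.train.Saver (assuming the training session and ops above), which writes model.meta together with the checkpoint files that tf.train.latest_checkpoint later finds.

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # ... run the training loop here ...
    saver.save(sess, "./model")  # writes model.meta, model.index, model.data-*, checkpoint
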
To run inference with the saved model, I need the encoder and decoder LSTMs. How do I restore these LSTMs after saving the model? I tried the following:

with tf.variable_scope("Encoder"):
    encoder_lstm = tf.nn.rnn_cell.BasicLSTMCell(latent_dim, name = 'encoder_lstm')
with tf.variable_scope("Decoder"):
    decoder_lstm = tf.nn.rnn_cell.BasicLSTMCell(latent_dim, name = 'decoder_lstm')

sess = tf.Session()
saver = tf.train.import_meta_graph("model.meta")
saver.restore(sess, tf.train.latest_checkpoint("./"))
graph = tf.get_default_graph()

However, this does not load the trained weights into the encoder and decoder LSTMs. How can I fix this?
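
One common approach, sketched below under assumptions about the rest of the graph, is not to redefine bare cells but to rebuild the exact training graph (same code, same variable scopes such as "Encoder" and "Decoder_with_Attention") and then restore with a plain tf.train.Saver, which matches variables by their scoped names; note that a BasicLSTMCell creates no variables until it is actually called inside dynamic_rnn or dynamic_decode. The placeholder name and emb_dim below are hypothetical.

# A minimal sketch, assuming latent_dim and the training-graph construction code are available.
tf.reset_default_graph()

# Hypothetical placeholder; its name and shape depend on how the training inputs were defined.
encoder_emb_input = tf.placeholder(tf.float32, [None, None, emb_dim], name='encoder_emb_input')

with tf.variable_scope("Encoder"):
    encoder_lstm = tf.nn.rnn_cell.BasicLSTMCell(latent_dim, name='encoder_lstm')
    enc_outputs, enc_states = tf.nn.dynamic_rnn(encoder_lstm, encoder_emb_input,
                                                time_major=False, dtype=tf.float32)

# ... rebuild the decoder under "Decoder_with_Attention" exactly as during training,
# typically swapping TrainingHelper for tf.contrib.seq2seq.GreedyEmbeddingHelper ...

saver = tf.train.Saver()  # matches variables in the rebuilt graph by their scoped names
sess = tf.Session()
saver.restore(sess, tf.train.latest_checkpoint("./"))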

0 Answers:

There are no answers yet.