在TensorFlow中,保存和恢复LSTM模型有什么不同吗?

时间:2017-10-01 07:45:38

标签: tensorflow lstm

我已尝试根据tutorial保存和恢复LSTM模型。实际上,它可以像CNN模型一样在普通模型中保存和恢复。但是,当我尝试恢复LSTM模型时,它会抛出错误

  

FailedPreconditionError(参见上面的回溯):尝试使用未初始化的值RNN_model / RNN / multi_rnn_cell / cell_0 / basic_lstm_cell / kernel

关键代码显示如下:

   with tf.Session() as sess:
    saver.restore(sess, saver_path)
    with tf.variable_scope('RNN_model', reuse=None):
        train_rnn = RNNmodel.LSTMmodel(True, RNNmodel.TRAIN_BATCH_SIZE, RNNmodel.NUM_STEP)

    with tf.variable_scope('RNN_model', reuse=True):
        test_rnn = RNNmodel.LSTMmodel(False, RNNmodel.EVAL_BATCH_SIZE, RNNmodel.NUM_STEP)

我想知道正常模型和LSTM模型在保存和恢复方面是否有任何区别。请帮忙

修改: 我尝试移动restore并且它有效,但是当我运行我的模型时,它仍然会抛出相同的错误,我的run_epoch代码如下:

def run_epoch(session, model, datas, train_op, is_log, epoch=3000):
    state = session.run(model.initiate_state)
    total_cost = 0
    for i in range(epoch):
        data, label = random_get_data(datas, model.batch_size, num_step=RNNmodel.NUM_STEP)
        feed_dict = {
            model.input_data: data,
            model.target: label,
            model.initiate_state: state
        }
        cost, state, argmax_logit, target, _ = session.run([model.loss, model.final_state, model.argmax_target, model.target, train_op], feed_dict)

日志将错误定位在:

cost, state, argmax_logit, target, _ = session.run([model.loss, model.final_state, model.argmax_target, model.target, train_op], feed_dict)

,日志显示如下:

  

tensorflow.python.framework.errors_impl.FailedPreconditionError:尝试使用未初始化的值RNN_model / RNN / multi_rnn_cell / cell_0 / basic_lstm_cell / kernel

似乎restore没有恢复lstm内核操作。我应该做什么来专门启动lstm操作?

EDIT2 : 我最后检查了checkpoint文件,我确信save操作不会保存关于LSTM细胞的变量,我不知道为什么。似乎我必须明确命名变量,否则我不能保存它,BasicLSTMCell类 init ()不具有name参数。

2 个答案:

答案 0 :(得分:0)

没有区别,RNN使用普通变量。

我认为你必须搬家

saver.restore(sess, saver_path)
创建LSTMmodel后

。否则,当您调用恢复时,它的变量不在图表中 - 因此它们将无法恢复。

答案 1 :(得分:0)

我终于明白了。根据{{​​3}}和 Jeronimo Garcia-Loygorri 的答案,我按照LSTM模型的定义移动Saver的创建,然后每个问题都消失了!