TF LSTM:稍后从预测会话的训练会话中保存状态

时间:2017-04-27 16:56:43

标签: tensorflow lstm

我正在尝试将最新的LSTM状态从训练中保存下来,以便在预测阶段重复使用。我遇到的问题是在TF LSTM模型中,State通过占位符和numpy数组的组合从一个训练迭代传递到下一个 - 在会话中默认情况下这两者似乎都没有包含在Graph中保存了。

要解决这个问题,我正在创建一个专用的TF变量来保存最新版本的状态,以便将其添加到Session图中,如下所示:

# latest State from last training iteration:
_, y, ostate, smm = sess.run([train_step, Y, H, summaries], feed_dict=feed_dict)
# now add to TF variable:
savedState = tf.Variable(ostate, dtype=tf.float32, name='savedState')
tf.variables_initializer([savedState]).run()
save_path = saver.save(sess, pathModel + '/my_model.ckpt')

这似乎很好地将savedState变量添加到保存的会话图中,并且稍后可以在会话的其余部分轻松恢复。

问题是,我在恢复的Session中稍后设法实际使用该变量的唯一方法是,如果我在恢复它之后初始化会话中的所有变量(这似乎重置了所有训练过的变量,包括权重/偏见/等!)。如果我首先初始化变量然后恢复会话(这在保留训练的varialbes方面工作正常),那么我收到一个错误,我正在尝试访问未初始化的变量。

我知道有一种方法可以初始化一个特定的个体varialbe(我在保存它时使用它)但问题是当我们恢复它们时,我们通过名称将它们称为字符串,我们不只是通过变量本身?!

# This produces an error 'trying to use an uninitialized varialbe
gInit = tf.global_variables_initializer().run()
new_saver = tf.train.import_meta_graph(pathModel + 'my_model.ckpt.meta')
new_saver.restore(sess, pathModel + 'my_model.ckpt')
fullState = sess.run('savedState:0')

完成这项工作的正确方法是什么?作为一种解决方法,我目前正在将状态保存为CSV,就像一个numpy数组,然后以相同的方式恢复它。它工作正常,但显然不是最干净的解决方案,因为保存/恢复TF会话的其他方面都能正常工作。

任何建议都赞赏!

**编辑: 这里的代码运行良好,如下面接受的答案所述:

# make sure to define the State variable before the Saver variable:
savedState = tf.get_variable('savedState', shape=[BATCHSIZE, CELL_SIZE * LAYERS])
saver = tf.train.Saver(max_to_keep=1)
# last training iteration:
_, y, ostate, smm = sess.run([train_step, Y, H, summaries], feed_dict=feed_dict)
# now save the State and the whole model:
assignOp = tf.assign(savedState, ostate)
sess.run(assignOp)
save_path = saver.save(sess, pathModel + '/my_model.ckpt')


# later on, in some other program, recover the model and the State:
# make sure to initialize all variables BEFORE recovering the model!
gInit = tf.global_variables_initializer().run()
local_saver = tf.train.import_meta_graph(pathModel + 'my_model.ckpt.meta')
local_saver.restore(sess, pathModel + 'my_model.ckpt')
# recover the state from training and get its last dimension
fullState = sess.run('savedState:0')
h = fullState[-1]
h = np.reshape(h, [1, -1])

我还没有测试过这种方法是否无意中初始化了已保存会话中的任何其他变量,但是不明白为什么应该这样做,因为我们只运行特定的变量。

1 个答案:

答案 0 :(得分:1)

问题是在构建tf.Variable之后创建新的Saver意味着Saver不知道新变量。它仍然保存在元图中,但未保存在检查点中:

import tensorflow as tf
with tf.Graph().as_default():
  var_a = tf.get_variable("a", shape=[])
  saver = tf.train.Saver()
  var_b = tf.get_variable("b", shape=[])
  print(saver._var_list) # [<tf.Variable 'a:0' shape=() dtype=float32_ref>]
  initializer = tf.global_variables_initializer()
  with tf.Session() as session:
    session.run([initializer])
    saver.save(session, "/tmp/model", global_step=0)
with tf.Graph().as_default():
  new_saver = tf.train.import_meta_graph("/tmp/model-0.meta")
  print(saver._var_list) # [<tf.Variable 'a:0' shape=() dtype=float32_ref>]
  with tf.Session() as session:
    new_saver.restore(session, "/tmp/model-0") # Only var_a gets restored!

我已使用Saver知道的变量注释了上述问题的快速复制。

现在,解决方案相对容易。我建议您在Variable之前创建Saver,然后使用tf.assign更新其值(确保运行 tf.assign返回的操作)。分配的值将保存在检查点中,并像其他变量一样进行恢复。

Saver传递给None构造函数参数(即它可以自动获取新变量)时,var_list可以更好地处理这个问题。为此,请随意open a feature request on Github