How to use tf.train.Saver in a SessionRunHook?

Posted: 2018-03-28 12:48:15

Tags: tensorflow

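I have trained many sub-models, each of which is a part of the final model. I want to use these pre-trained sub-models to initialize the parameters of the final model, so I tried to use a SessionRunHook to load parameters from the other ckpt files into the final model. The code below fails, and I would appreciate any suggestions. Thanks! The error message is: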

Traceback (most recent call last):
  File "train_high_api_local.py", line 282, in <module>
    tf.app.run()
  File "/Users/zhouliaoming/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/platform/app.py", line 124, in run
    _sys.exit(main(argv))
  File "train_high_api_local.py", line 266, in main
    clf_.train(input_fn=lambda: read_file([tables[0]], epochs_per_eval), steps=None, hooks=[hook_test])     # input yield: x, y
  File "/Users/zhouliaoming/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 314, in train
  .......
  File "/Users/zhouliaoming/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py", line 674, in create_session
    hook.after_create_session(self.tf_sess, self.coord)
  File "train_high_api_local.py", line 102, in after_create_session
    saver = tf.train.Saver([ti])    # TODO: ERROR INFO:  Graph is finalized and cannot be modified.
  .......
  File "/Users/zhouliaoming/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3135, in create_op
    self._check_not_finalized()
  File "/Users/zhouliaoming/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2788, in _check_not_finalized
    raise RuntimeError("Graph is finalized and cannot be modified.")
RuntimeError: Graph is finalized and cannot be modified.

The code is:

class SetTensor(session_run_hook.SessionRunHook):
    """ like tf.train.LoggingTensorHook  """        
    def after_create_session(self, session, coord):
        """ Called when new TensorFlow session is created: graph is finalized and ops can no longer be added.  """
        graph = tf.get_default_graph()
        ti = graph.get_tensor_by_name("h_1_15/bias:0")
        with session.as_default():
            with tf.name_scope("rewrite"):
                saver = tf.train.Saver([ti])    # TODO: ERROR INFO:  Graph is finalized and cannot be modified.
                saver.restore(session, "/Users/zhouliaoming/data/credit_dnn/model_retrain/rm_gene_v2_sall/model.ckpt-2102")
        pass        

def main(unused_argv):
    """ train """
    norm_all_func = lambda x:  tf.cond(x>1, lambda: tf.log(x), lambda: tf.identity(x))
    feature_columns=[[tf.feature_column.numeric_column(COLUMNS[i], shape=fi, normalizer_fn=lambda x: tf.py_func(weight_norm2, [x], tf.float32) )] for i, fi in enumerate(FEA_DIM)]  # normalized: running OK!
    ## use self-defined model
    param = {"learning_rate": 0.0001, "feature_columns": feature_columns, "isanalysis": FLAGS.isanalysis, "isall": False}
    clf_ = tf.estimator.Estimator(model_fn=model_fn_wide2deep, params=param, model_dir=ckpt_dir)
    hook_test = SetTensor(["h_1_15/bias", "h_1_15/kernel"])
    epochs_per_eval = 1
    for n in range(int(FLAGS.num_epochs/epochs_per_eval)):
        # train num_epochs
        clf_.train(input_fn=lambda: read_file([tables[0]], epochs_per_eval), steps=None, hooks=[hook_test])     # input yield: x, y

2 Answers:

Answer 0 (score: 1)

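A SessionRunHook is not suitable for this use case. As the error says, the graph can no longer be modified once the session exists: MonitoredSession finalizes the graph before it calls after_create_session (and therefore before any sess.run()), and constructing a tf.train.Saver adds new ops to the graph.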
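The failure can be reproduced without an Estimator. This is a minimal sketch, assuming the TF 1.x graph API through tensorflow.compat.v1 (so it also runs on TF 2.x); calling graph.finalize() by hand stands in for what MonitoredSession does before after_create_session, and the variable name is borrowed from the question.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graph mode

graph = tf.Graph()
with graph.as_default():
    with tf.variable_scope("h_1_15"):
        bias = tf.get_variable("bias", shape=[3])

# MonitoredSession (used by Estimator.train) finalizes the graph like this
# before any hook's after_create_session() is called.
graph.finalize()

error_message = None
try:
    with graph.as_default():
        # A Saver adds save/restore ops, which a finalized graph rejects.
        tf.train.Saver([bias])
except RuntimeError as e:
    error_message = str(e)

print(error_message)
```

Running ops that already exist in a finalized graph is still allowed; only creating new ones is not.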

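You can assign the variables with saver.restore() in "ordinary" code; it does not have to happen inside any hook. The Saver just has to be created while the graph can still be modified.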
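A sketch of restoring part of a model in ordinary graph code, assuming the TF 1.x API through tensorflow.compat.v1. The "sub-model" checkpoint is created on the spot as a hypothetical stand-in for the asker's pre-trained ckpt, and h_2/kernel is a hypothetical variable that the checkpoint does not contain. The key ordering is: build the Saver before the session, run the initializers, then restore the subset on top.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

ckpt_path = os.path.join(tempfile.mkdtemp(), "sub_model.ckpt")

# 1) Hypothetical pre-trained sub-model that contains h_1_15/bias.
with tf.Graph().as_default():
    with tf.variable_scope("h_1_15"):
        tf.get_variable("bias", initializer=tf.constant([1.0, 2.0, 3.0]))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().save(sess, ckpt_path)

# 2) Final model: the same variable plus one that is not in the checkpoint.
with tf.Graph().as_default():
    with tf.variable_scope("h_1_15"):
        bias = tf.get_variable("bias", shape=[3], initializer=tf.zeros_initializer())
    with tf.variable_scope("h_2"):
        other = tf.get_variable("kernel", shape=[3, 3], initializer=tf.ones_initializer())

    # Created while the graph is still modifiable, limited to the variables
    # that the checkpoint actually holds.
    restorer = tf.train.Saver(var_list=[bias])

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())  # initialize everything first
        restorer.restore(sess, ckpt_path)            # then overwrite the subset
        restored_bias, other_value = sess.run([bias, other])

print(restored_bias.tolist())
```

Inside an Estimator, the corresponding "ordinary code" is the model_fn. An alternative not mentioned in these answers is tf.train.init_from_checkpoint(ckpt_dir_or_file, assignment_map), which, when called in model_fn, replaces the initializers of the mapped variables with values from the checkpoint.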

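Also, if you want to restore many variables and can match them to the names and shapes in the checkpoint, you may want to look at https://gist.github.com/iganichev/d2d8a0b1abc6b15d4a07de83171163d4. It shows some sample code for restoring a subset of variables.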
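The name-and-shape matching idea can be sketched as follows. This is not the gist's code; it is an independent illustration assuming the TF 1.x API through tensorflow.compat.v1, where the helper matching_variables and all variable names are hypothetical. It relies on tf.train.list_variables, which returns (name, shape) pairs for a checkpoint.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def matching_variables(ckpt_path):
    """Hypothetical helper: map checkpoint names to global variables whose
    name and shape both match an entry in the checkpoint."""
    ckpt_shapes = dict(tf.train.list_variables(ckpt_path))
    var_list = {}
    for var in tf.global_variables():
        name = var.op.name  # e.g. "h_1_15/bias", without the ":0" suffix
        if name in ckpt_shapes and var.shape.as_list() == list(ckpt_shapes[name]):
            var_list[name] = var
    return var_list


ckpt_path = os.path.join(tempfile.mkdtemp(), "sub_model.ckpt")

# Hypothetical sub-model checkpoint: h_1_15/bias [3] and h_1_15/kernel [2, 3].
with tf.Graph().as_default():
    with tf.variable_scope("h_1_15"):
        tf.get_variable("bias", initializer=tf.constant([1.0, 2.0, 3.0]))
        tf.get_variable("kernel", initializer=tf.ones([2, 3]))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().save(sess, ckpt_path)

# Final model: kernel has a different shape and h_2/bias is new.
with tf.Graph().as_default():
    with tf.variable_scope("h_1_15"):
        bias = tf.get_variable("bias", shape=[3], initializer=tf.zeros_initializer())
        tf.get_variable("kernel", shape=[4, 3], initializer=tf.zeros_initializer())
    with tf.variable_scope("h_2"):
        tf.get_variable("bias", shape=[3], initializer=tf.zeros_initializer())

    var_list = matching_variables(ckpt_path)  # only h_1_15/bias qualifies
    restorer = tf.train.Saver(var_list=var_list)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        restorer.restore(sess, ckpt_path)
        restored_bias = sess.run(bias)

print(sorted(var_list))
```

Passing a dict to Saver lets the checkpoint names differ from the graph names if the sub-models used other scopes; here they happen to be identical.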

Answer 1 (score: 1)

You can create the Saver in begin(), which is called before the graph is finalized, and only use it later:

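import tensorflow as tf

class SaveAtEnd(tf.train.SessionRunHook):
  def begin(self):
    # begin() runs before the graph is finalized, so the Saver's ops can be added here.
    self._saver = tf.train.Saver()  # create your saver (optionally with a var_list)

  def end(self, session):
    # Only runs the existing save ops; nothing is added to the finalized graph.
    self._saver.save(session, ...)  # replace ... with your checkpoint path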
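The same pattern covers the restore the question asks about: build the Saver in begin() and call restore() in after_create_session(), where MonitoredSession has already run the variable initializers. This is a sketch assuming the TF 1.x API through tensorflow.compat.v1; RestoreSubsetHook, its arguments, and the throwaway checkpoint are hypothetical.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


class RestoreSubsetHook(tf.train.SessionRunHook):
    """Hypothetical hook: restores `var_names` from `ckpt_path` after init."""

    def __init__(self, ckpt_path, var_names):
        self._ckpt_path = ckpt_path
        self._var_names = set(var_names)
        self._saver = None

    def begin(self):
        # The graph is not finalized yet, so creating Saver ops is allowed.
        var_list = [v for v in tf.global_variables() if v.op.name in self._var_names]
        self._saver = tf.train.Saver(var_list=var_list)

    def after_create_session(self, session, coord):
        # The graph is finalized now, but running existing restore ops is fine.
        self._saver.restore(session, self._ckpt_path)


ckpt_path = os.path.join(tempfile.mkdtemp(), "sub_model.ckpt")

# Hypothetical pre-trained sub-model checkpoint.
with tf.Graph().as_default():
    with tf.variable_scope("h_1_15"):
        tf.get_variable("bias", initializer=tf.constant([1.0, 2.0, 3.0]))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().save(sess, ckpt_path)

# Final model, initialized to zeros and then overwritten by the hook.
with tf.Graph().as_default():
    with tf.variable_scope("h_1_15"):
        bias = tf.get_variable("bias", shape=[3], initializer=tf.zeros_initializer())
    hook = RestoreSubsetHook(ckpt_path, ["h_1_15/bias"])
    with tf.train.MonitoredSession(hooks=[hook]) as sess:
        restored_bias = sess.run(bias)

print(restored_bias.tolist())
```

One caveat: after_create_session runs for every new session. In the question's loop, which calls clf_.train() once per epoch, such a hook would reset the restored layer to its pre-trained values at the start of every epoch, so it should probably be passed only on the first call.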