Where does `tf.train.Saver` collect the variables to save?

Asked: 2018-03-28 07:27:34

Tags: python tensorflow

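I found that if you want to save all the variables of a graph, you have to define the tf.train.Saver at the very end of the graph definition; otherwise the saver does not pick up all the variables.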

Here is my test code:

import tensorflow as tf


def how_saver_work():
    g = tf.Graph()

    with g.as_default():

        a = tf.Variable(1, name='a')
        b = tf.Variable(2, name='b')

        print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
        # The Saver is built here, when only a and b exist in the graph.
        saver = tf.train.Saver()

        c = tf.Variable(3, name='c')
        print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        with tf.Session() as sess:
            tf.global_variables_initializer().run()

            print(a.eval())
            print(b.eval())
            print(c.eval())

            save_path = saver.save(sess, "./tmp/model.ckpt")

    with tf.Session(graph=g) as sess:
        print(g.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
        saver.restore(sess, save_path)

        print(a.eval())
        print(b.eval())
        print(c.eval())  # fails: c was never saved or restored, so it is uninitialized

At first I thought the saver might get its variables from tf.GraphKeys.GLOBAL_VARIABLES or tf.GraphKeys.SAVEABLE_OBJECTS, but that does not seem to be right.

I would also like to know how to add a new variable to an existing Saver, for example the variable c.

1 answer:

Answer 0: (score: 1)

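You are right: the saver does get its variables from the union of tf.GraphKeys.GLOBAL_VARIABLES and tf.GraphKeys.SAVEABLE_OBJECTS, but it does so at construction time (see the Saver._build implementation, or the excerpt below):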

if self._var_list is None:
    # pylint: disable=protected-access
    self._var_list = variables._all_saveable_objects()

where _all_saveable_objects is defined in python/ops/variables.py:

def _all_saveable_objects(scope=None):
  """Returns all variables and `SaveableObject`s that must be checkpointed.

  Args:
    scope: (Optional.) A string. If supplied, the resulting list is filtered
      to include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a
      scope is supplied. The choice of `re.match` means that a `scope` without
      special tokens filters by prefix.

  Returns:
    A list of `Variable` and `SaveableObject` to be checkpointed
  """
  # TODO(andreasst): make this function public once things are settled.
  return (ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope) +
          ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope))
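
To make the snapshot behaviour concrete, here is a minimal sketch that reads the same two collections through the public get_collection call, once where the question builds its Saver and once after c is defined. It assumes TensorFlow 1.14+ or 2.x used through tensorflow.compat.v1; the helper name saveable_names is made up for illustration and is not part of TensorFlow.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # assumption: TF1-style graph mode, as in the question


def saveable_names(graph):
    """Hypothetical helper: names found in the two collections a default Saver reads."""
    items = (graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) +
             graph.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS))
    return [item.name for item in items]


g = tf.Graph()
with g.as_default():
    a = tf.Variable(1, name='a')
    b = tf.Variable(2, name='b')
    # This is the point where the question constructs its Saver.
    names_at_construction = saveable_names(g)

    c = tf.Variable(3, name='c')
    # The collection keeps growing, but an already-built Saver does not re-read it.
    names_afterwards = saveable_names(g)

print(names_at_construction)
print(names_afterwards)
```

The first list lacks c, which matches what a Saver built at that point would have captured.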

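There is no way to add a variable to a saver other than creating a new saver (cf. https://github.com/tensorflow/tensorflow/issues/2489#issuecomment-221282483):

"When you create a tf.train.Saver with no arguments, it implicitly uses the set of variables that exist at the time the saver is constructed when it saves and restores. If you add a new variable [...], you must create a new tf.train.Saver to save it."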
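
As a rough sketch of that advice, the snippet below builds one saver before c exists and a second one afterwards, then lists what each checkpoint contains with tf.train.list_variables. It assumes TensorFlow 1.14+ or 2.x through tensorflow.compat.v1; the names early_saver and late_saver and the temporary checkpoint directory are illustrative choices, not anything from the question.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # assumption: TF1-style graph mode, as in the question

g = tf.Graph()
with g.as_default():
    a = tf.Variable(1, name='a')
    b = tf.Variable(2, name='b')
    early_saver = tf.train.Saver()  # built before c exists

    c = tf.Variable(3, name='c')
    late_saver = tf.train.Saver()   # a new saver, built after c exists
    init_op = tf.global_variables_initializer()

ckpt_dir = tempfile.mkdtemp()  # illustrative location for the checkpoints
with tf.Session(graph=g) as sess:
    sess.run(init_op)
    early_path = early_saver.save(sess, os.path.join(ckpt_dir, 'early.ckpt'))
    late_path = late_saver.save(sess, os.path.join(ckpt_dir, 'late.ckpt'))

# list_variables returns (name, shape) pairs stored in a checkpoint.
early_names = sorted(name for name, _ in tf.train.list_variables(early_path))
late_names = sorted(name for name, _ in tf.train.list_variables(late_path))
print(early_names)
print(late_names)

# Restoring with the later saver in a fresh session also brings back c.
with tf.Session(graph=g) as sess:
    late_saver.restore(sess, late_path)
    restored_c = sess.run(c)
print(restored_c)
```

Passing an explicit list such as tf.train.Saver(var_list=[a, b, c]) should work as well, but c still has to exist when that saver is constructed.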