tf.train.init_from_checkpoint不会初始化使用tf.Variable创建的变量

时间:2019-02-27 12:16:22

标签: python tensorflow

似乎tf.train.init_from_checkpoint会初始化通过tf.get_variable创建的变量,而不会初始化通过tf.Variable创建的变量。

例如,让我们创建两个变量并将其保存:

import tensorflow as tf

tf.Variable(1.0, name='foo')
tf.get_variable('bar',initializer=1.0)
saver = tf.train.Saver()
with tf.Session() as sess:
  tf.global_variables_initializer().run()
  saver.save(sess, './model', global_step=0)

如果我再次通过tf.train.Saver加载它们,则一切正常:即使在此处将变量初始化为零,变量也会重新加载为1:

import tensorflow as tf

foo = tf.Variable(0.0, name='foo')
bar = tf.get_variable('bar', initializer=0.0)
saver = tf.train.Saver()
with tf.Session() as sess:
  saver.restore(sess, './model-0')
  print(f'foo: {foo.eval()}  bar: {bar.eval()}')
  # foo: 1.0  bar: 1.0

但是如果我使用tf.train.init_from_checkpoint,我会得到

import tensorflow as tf

foo = tf.Variable(0.0, name='foo')
bar = tf.get_variable('bar', initializer=0.0)
tf.train.init_from_checkpoint('./model-0', {'/':'/'})
with tf.Session() as sess:
  tf.global_variables_initializer().run()
  print(f'foo: {foo.eval()}  bar: {bar.eval()}')
  # foo: 0.0  bar: 1.0

bar如预期般设置回1,但foo仍为0。

这是预期的行为吗?如果是这样,为什么?

1 个答案:

答案 0 :(得分:2)

是的,这是有意的。此行为在_init_from_checkpoint方法中进行了描述,该方法在加载要还原的变量时遍历分配图。

 for tensor_name_in_ckpt, current_var_or_name in sorted(
      six.iteritems(assignment_map)):
    var = None

它首先设置要恢复到None的变量,如果满足以下几个条件之一,它将重置为当前变量名。在这种情况下,循环包含语句

if "/" in current_var_or_name

因此,它将从较早创建的字典store_vars中加载变量。它是在_init_from_checkpoint检查分配映射中的当前变量是否为tf.Variable(此时为False)之后立即创建的。

 if _is_variable(current_var_or_name) or (
        isinstance(current_var_or_name, list)
        and all(_is_variable(v) for v in current_var_or_name)):
      var = current_var_or_name
    else:
      store_vars = vs._get_default_variable_store()._vars 

store_vars由内部类_VariableStore创建,更确切地说,是由_get_default_variable_store()方法创建的。此类使用get_variable作为变量构造函数。由于tf.Variable没有默认范围,因此tf.get_variable首先调用tf.get_variable_scope(),它返回当前变量范围。 'foo'不在此范围内。除tf.Variable外,每次调用它时都会创建一个新变量,并且不允许共享。

store_vars是由默认作用域成员构造而成的,因此,它仅包含“ bar”变量,并且稍后将foo op添加到变量集合中tf.Variable

但是,如果assignment_map将包含{'foo':foo, 'bar':bar},则上述_init_from_checkpoint中的变量将找到并加载它们。因此,在这种情况下,您的代码将输出foo: 1.0 bar: 1.0

您可以在https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/checkpoint_utils.py

中找到代码

https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/ops/variable_scope.py 另请参阅此答案What is the default variable_scope in Tensorflow?