似乎tf.train.init_from_checkpoint
会初始化通过tf.get_variable
创建的变量,而不会初始化通过tf.Variable
创建的变量。
例如,让我们创建两个变量并将其保存:
import tensorflow as tf
tf.Variable(1.0, name='foo')
tf.get_variable('bar',initializer=1.0)
saver = tf.train.Saver()
with tf.Session() as sess:
tf.global_variables_initializer().run()
saver.save(sess, './model', global_step=0)
如果我再次通过tf.train.Saver
加载它们,则一切正常:即使在此处将变量初始化为零,变量也会重新加载为1:
import tensorflow as tf
foo = tf.Variable(0.0, name='foo')
bar = tf.get_variable('bar', initializer=0.0)
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, './model-0')
print(f'foo: {foo.eval()} bar: {bar.eval()}')
# foo: 1.0 bar: 1.0
但是如果我使用tf.train.init_from_checkpoint
,我会得到
import tensorflow as tf
foo = tf.Variable(0.0, name='foo')
bar = tf.get_variable('bar', initializer=0.0)
tf.train.init_from_checkpoint('./model-0', {'/':'/'})
with tf.Session() as sess:
tf.global_variables_initializer().run()
print(f'foo: {foo.eval()} bar: {bar.eval()}')
# foo: 0.0 bar: 1.0
bar
如预期般设置回1,但foo
仍为0。
这是预期的行为吗?如果是这样,为什么?
答案 0 :(得分:2)
是的,这是有意的。此行为在_init_from_checkpoint
方法中进行了描述,该方法在加载要还原的变量时遍历分配图。
for tensor_name_in_ckpt, current_var_or_name in sorted(
six.iteritems(assignment_map)):
var = None
它首先设置要恢复到None
的变量,如果满足以下几个条件之一,它将重置为当前变量名。在这种情况下,循环包含语句
if "/" in current_var_or_name
因此,它将从较早创建的字典store_vars
中加载变量。它是在_init_from_checkpoint
检查分配映射中的当前变量是否为tf.Variable
(此时为False)之后立即创建的。
if _is_variable(current_var_or_name) or (
isinstance(current_var_or_name, list)
and all(_is_variable(v) for v in current_var_or_name)):
var = current_var_or_name
else:
store_vars = vs._get_default_variable_store()._vars
store_vars
由内部类_VariableStore
创建,更确切地说,是由_get_default_variable_store()
方法创建的。此类使用get_variable
作为变量构造函数。由于tf.Variable
没有默认范围,因此tf.get_variable
首先调用tf.get_variable_scope(),它返回当前变量范围。 'foo'不在此范围内。除tf.Variable
外,每次调用它时都会创建一个新变量,并且不允许共享。
store_vars
是由默认作用域成员构造而成的,因此,它仅包含“ bar”变量,并且稍后将foo
op添加到变量集合中tf.Variable
。
但是,如果assignment_map
将包含{'foo':foo, 'bar':bar}
,则上述_init_from_checkpoint
中的变量将找到并加载它们。因此,在这种情况下,您的代码将输出foo: 1.0 bar: 1.0
您可以在https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/checkpoint_utils.py
中找到代码和 https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/ops/variable_scope.py 另请参阅此答案What is the default variable_scope in Tensorflow?