I've found that if you want to save all the variables of a graph, you have to define the tf.train.Saver at the very end of the graph; otherwise the saver cannot capture all the variables.
Here is my test code:
import tensorflow as tf  # TF 1.x

def how_saver_work():
    g = tf.Graph()
    with g.as_default():
        a = tf.Variable(1, name='a')
        b = tf.Variable(2, name='b')
        print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))  # [a, b]
        saver = tf.train.Saver()  # built before c is defined
        c = tf.Variable(3, name='c')
        print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))  # [a, b, c]
        with tf.Session() as sess:
            tf.global_variables_initializer().run()
            print(a.eval())
            print(b.eval())
            print(c.eval())
            save_path = saver.save(sess, "./tmp/model.ckpt")
    with tf.Session(graph=g) as sess:
        print(g.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
        saver.restore(sess, save_path)
        print(a.eval())
        print(b.eval())
        print(c.eval())  # err: uninitialized
I first thought the saver might collect its variables from tf.GraphKeys.GLOBAL_VARIABLES or tf.GraphKeys.SAVEABLE_OBJECTS, but that didn't seem right to me.
I would also like to know how to add a new variable to a Saver, for example the variable c.
Answer 0 (score: 1)
You are right: the saver does collect its variables from the union of tf.GraphKeys.GLOBAL_VARIABLES and tf.GraphKeys.SAVEABLE_OBJECTS, and it does so at construction time (see the Saver._build implementation, or the excerpt below):
if self._var_list is None:
  # pylint: disable=protected-access
  self._var_list = variables._all_saveable_objects()
where _all_saveable_objects is defined in the python/ops/variables.py file:
def _all_saveable_objects(scope=None):
  """Returns all variables and `SaveableObject`s that must be checkpointed.

  Args:
    scope: (Optional.) A string. If supplied, the resulting list is filtered
      to include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a
      scope is supplied. The choice of `re.match` means that a `scope` without
      special tokens filters by prefix.

  Returns:
    A list of `Variable` and `SaveableObject` to be checkpointed
  """
  # TODO(andreasst): make this function public once things are settled.
  return (ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope) +
          ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope))
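Because the collections are only consulted once, at construction time, you can also control exactly which variables a saver covers by passing an explicit var_list (a regular tf.train.Saver argument in the TF 1.x API). A minimal sketch, assuming TF 1.x (the variable names here are just illustrative):

import tensorflow as tf  # TF 1.x API assumed

g = tf.Graph()
with g.as_default():
    a = tf.Variable(1, name='a')
    b = tf.Variable(2, name='b')
    # Explicit var_list: this saver handles only `a`, regardless of what
    # is in GLOBAL_VARIABLES now or later.
    saver_a_only = tf.train.Saver(var_list=[a])
    # No arguments: the saver snapshots GLOBAL_VARIABLES + SAVEABLE_OBJECTS
    # as they exist right now, i.e. [a, b].
    saver_all = tf.train.Saver()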
There is no way to add variables to an existing saver other than by creating a new one (cf. https://github.com/tensorflow/tensorflow/issues/2489#issuecomment-221282483):
When you create a tf.train.Saver with no arguments, it implicitly uses the set of variables that exists at saver construction time when saving and restoring. If you add new variables [...], you must create a new tf.train.Saver to save them.
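Applied to the code in the question, that simply means building a second saver after c has been created. A minimal sketch (the checkpoint path is just an example):

import tensorflow as tf  # TF 1.x API assumed

g = tf.Graph()
with g.as_default():
    a = tf.Variable(1, name='a')
    b = tf.Variable(2, name='b')
    saver_ab = tf.train.Saver()   # only covers a and b
    c = tf.Variable(3, name='c')
    saver_abc = tf.train.Saver()  # built after c, so it covers a, b and c

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        save_path = saver_abc.save(sess, "./tmp/model_abc.ckpt")

    with tf.Session(graph=g) as sess:
        saver_abc.restore(sess, save_path)
        print(c.eval())  # c is restored, no "uninitialized" error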