在Saver documentation中,声明Saver对象可以将列表或字典作为输入,如果是dictionaris,则键必须是用于保存或恢复变量的名称。我有一个代码如下所示:
create_network()
vars_to_load_list = ...
vars_to_load_dict = {v.name:v for v in vars_to_load_list}
loader = tf.train.Saver(var_list=vars_to_load_list, max_to_keep=FLAGS.max_epoch)
path = ...
latest_ckpt = tf.train.latest_checkpoint(path, latest_filename=None)
sess = tf.Session()
ckpt = tf.train.get_checkpoint_state(path)
if ckpt and ckpt.model_checkpoint_path:
loader.restore(sess, save_path=latest_ckpt)
上面的代码有效,但如果我传入变量字典而不是变量列表,即将loader
的定义更改为:
loader = tf.train.Saver(var_list=vars_to_load_dict, max_to_keep=FLAGS.max_epoch)
然后我得到一个NotFoundError
并且加载器抱怨在检查点文件中找不到一些Tensor名称。但我希望这两个版本的代码能够同样运行。我错过了什么吗?
答案 0 :(得分:3)
我弄明白了这个问题。显然,变量的name属性对应于变量的值而不是其张量(如果我对这些概念的理解是正确的)。即它返回"my_var:0"
,而加载程序需要"my_var"
。在上面的例子中修改字典的定义解决了这个问题:
vars_to_load_dict = {v.name[:-2]:v for v in vars_to_load_list}