将词典传递给tensorflow Saver

时间:2016-10-28 15:56:35

标签: python-2.7 tensorflow

Saver documentation中,声明Saver对象可以将列表或字典作为输入,如果是dictionaris,则键必须是用于保存或恢复变量的名称。我有一个代码如下所示:

create_network()
vars_to_load_list = ...
vars_to_load_dict = {v.name:v for v in vars_to_load_list}
loader = tf.train.Saver(var_list=vars_to_load_list, max_to_keep=FLAGS.max_epoch)
path = ...
latest_ckpt = tf.train.latest_checkpoint(path, latest_filename=None)
sess = tf.Session()
ckpt = tf.train.get_checkpoint_state(path)
if ckpt and ckpt.model_checkpoint_path:
  loader.restore(sess, save_path=latest_ckpt)

上面的代码有效,但如果我传入变量字典而不是变量列表,即将loader的定义更改为:

loader = tf.train.Saver(var_list=vars_to_load_dict, max_to_keep=FLAGS.max_epoch)

然后我得到一个NotFoundError并且加载器抱怨在检查点文件中找不到一些Tensor名称。但我希望这两个版本的代码能够同样运行。我错过了什么吗?

1 个答案:

答案 0 :(得分:3)

我弄明白了这个问题。显然,变量的name属性对应于变量的值而不是其张量(如果我对这些概念的理解是正确的)。即它返回"my_var:0",而加载程序需要"my_var"。在上面的例子中修改字典的定义解决了这个问题:

vars_to_load_dict = {v.name[:-2]:v for v in vars_to_load_list}