TensorFlow:从多个检查点恢复变量

时间:2016-03-01 21:29:49

标签: tensorflow

我有以下情况:

  • 我有2个模型用2个独立的脚本编写:

  • 模型A由变量a1a2a3组成,并以A.py

  • 编写
  • 模型B由变量b1b2b3组成,并以B.py

  • 编写

A.pyB.py的每一个中,我有一个tf.train.Saver来保存所有局部变量的检查点,让我们调用检查点文件ckptA和{{ 1}}分别。

我现在想制作一个使用ckptBa1的模型C.我可以通过使用var_scope(和b1的相同内容)在A和C中使用a1的完全相同的变量名称。

问题是我如何从b1a1加载b1ckptA到模型C?例如,以下工作会如何?

ckptB

如果您尝试两次恢复同一个会话,是否会引发错误?它是否会抱怨额外变量(saver.restore(session, ckptA_location) saver.restore(session, ckptB_location) b2b3a2)没有分配“插槽”,或者只是简单地恢复变量,如果C中有其他未初始化的变量,只会抱怨?

我正在尝试编写一些代码来测试这个,但我很想看到这个问题的规范方法,因为在尝试重新使用一些预先训练的权重时经常会遇到这种情况。

谢谢!

1 个答案:

答案 0 :(得分:19)

如果您尝试使用保护程序(默认情况下代表所有六个变量)从不包含保护程序所代表的所有变量的检查点进行恢复,则会得到tf.errors.NotFoundError。 (但请注意,对于任何变量子集,只要所有请求的变量都存在于相应的文件中,您就可以在同一会话中多次调用Saver.restore()。)

规范方法是定义两个单独的tf.train.Saver实例,涵盖完全包含在单个检查点中的每个变量子集。例如:

saver_a = tf.train.Saver([a1])
saver_b = tf.train.Saver([b1])

saver_a.restore(session, ckptA_location)
saver_b.restore(session, ckptB_location)

根据代码的构建方式,如果您指向本地范围内名为tf.Variablea1的{​​{1}}个对象,则可以在此处停止阅读。

另一方面,如果变量b1a1在单独的文件中定义,您可能需要做一些有创意的事情来检索指向这些变量的指针。虽然它并不理想,但人们通常做的是使用公共前缀,例如如下(假设变量名分别为b1"a1:0"):

"b1:0"

最后一点说明:您不必为确保变量在A和C中具有相同的名称而做出英勇的努力。您可以将名称 - { - 1}}字典作为第一个参数传递给saver_a = tf.train.Saver([v for v in tf.all_variables() if v.name == "a1:0"]) saver_b = tf.train.Saver([v for v in tf.all_variables() if v.name == "b1:0"]) 构造函数,从而将检查点文件中的名称重新映射到代码中的Variable个对象。如果tf.train.SaverVariable具有类似命名的变量,或者A.py中您希望根据tf.name_scope()中的这些文件组织模型代码,这会有所帮助。