我首先训练了网络N
并将其与保护程序一起保存到检查点Checkpoint_N
中。在N
中定义了一些变量范围。
现在,我想使用经过培训的网络N
建立一个连体网络,如下所示:
with tf.variable_scope('siameseN',reuse=False) as scope:
networkN = N()
embedding_1 = networkN.buildN()
# this defines the network graph and all the variables.
tf.train.Saver().restore(session_variable,Checkpoint_N)
scope.reuse_variables()
embedding_2 = networkN.buildN()
# define 2nd branch of the Siamese, by reusing previously restored variables.
当我执行上述操作时,还原语句会在Key Error
图表中的每个变量的检查点文件中抛出siameseN/conv1
N
。
有没有办法在不更改N
代码的情况下执行此操作?我只是基本上为N
中的每个变量和操作添加了一个父作用域。我可以通过告诉tensorflow忽略父范围或其他东西来将权重恢复到正确的变量吗?
答案 0 :(得分:5)
这与:How to restore weights with different names but same shapes Tensorflow?
有关 tf.train.Saver(var_list={'variable_name_in_checkpoint':var_to_be_restored_to,...'})
可以列出要恢复的变量列表或字典
(e.g. 'variable_name_in_checkpoint':var_to_be_restored_to,...)
您可以通过浏览当前会话变量中的所有变量来准备上述字典,并使用会话变量作为值并获取当前变量的名称,并删除' siameseN /'从变量名称和用作键。它理论上应该有效。
答案 1 :(得分:1)
我不得不稍微更改代码,编写自己的恢复功能。我决定将检查点文件作为字典加载,变量名称作为键,相应的numpy数组作为值加载如下:
checkpoint_path = '/path/to/checkpoint'
from tensorflow.python import pywrap_tensorflow
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
key_to_numpy = {}
for key in var_to_shape_map:
key_to_numpy[key] = reader.get_tensor(key)
我已经有了这个单一的函数,其中创建了所有变量,并且从图N
调用了所需的名称。我修改它以使用从字典查找获得的numpy数组初始化变量。并且,为了使查找成功,我只是删除了我添加的父名称范围,如下所示:
init = tf.constant(key_to_numpy[ name.split('siameseN/')[1] ])
var = tf.get_variable(name, initializer=init)
#var = tf.get_variable(name, shape, initializer=initializer)
return var
这是一个非常黑客的方法。我没有使用@edit的答案,因为我已经编写了上面的代码。此外,我的所有权重都是在一个函数中创建的,该函数将这些权重分配给变量var
并返回它。因为这类似于函数式编程,变量var
会被覆盖。 var
永远不会暴露给更高级别的职能。要使用@ edit的答案,我必须为每次初始化使用不同的张量变量名称,并以某种方式将它们暴露给更高级别的函数,以便保护程序可以在它们的var_to_be_restored_to
中使用它们答案。
但是@ edit的解决方案是不那么苛刻的解决方案,因为它遵循记录的用法。所以我接受了这个答案。我所做的可以是另一种解决方案。