Tensorflow在忽略范围名称或新范围名称时进行恢复

时间:2018-02-25 13:27:12

标签: python tensorflow

我首先训练了网络N并将其与保护程序一起保存到检查点Checkpoint_N中。在N中定义了一些变量范围。

现在,我想使用经过培训的网络N建立一个连体网络,如下所示:

with tf.variable_scope('siameseN',reuse=False) as scope:
  networkN = N()
  embedding_1 = networkN.buildN() 
  # this defines the network graph and all the variables.
  tf.train.Saver().restore(session_variable,Checkpoint_N)
  scope.reuse_variables()
  embedding_2 = networkN.buildN()
  # define 2nd branch of the Siamese, by reusing previously restored variables.

当我执行上述操作时,还原语句会在Key Error图表中的每个变量的检查点文件中抛出siameseN/conv1 N

有没有办法在不更改N代码的情况下执行此操作?我只是基本上为N中的每个变量和操作添加了一个父作用域。我可以通过告诉tensorflow忽略父范围或其他东西来将权重恢复到正确的变量吗?

2 个答案:

答案 0 :(得分:5)

这与:How to restore weights with different names but same shapes Tensorflow?

有关

tf.train.Saver(var_list={'variable_name_in_checkpoint':var_to_be_restored_to,...'})

可以列出要恢复的变量列表或字典

(e.g. 'variable_name_in_checkpoint':var_to_be_restored_to,...)

您可以通过浏览当前会话变量中的所有变量来准备上述字典,并使用会话变量作为值并获取当前变量的名称,并删除' siameseN /'从变量名称和用作键。它理论上应该有效。

答案 1 :(得分:1)

我不得不稍微更改代码,编写自己的恢复功能。我决定将检查点文件作为字典加载,变量名称作为键,相应的numpy数组作为值加载如下:

checkpoint_path = '/path/to/checkpoint'
from tensorflow.python import pywrap_tensorflow

reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()

key_to_numpy = {}
for key in var_to_shape_map:
  key_to_numpy[key] = reader.get_tensor(key)

我已经有了这个单一的函数,其中创建了所有变量,并且从图N调用了所需的名称。我修改它以使用从字典查找获得的numpy数组初始化变量。并且,为了使查找成功,我只是删除了我添加的父名称范围,如下所示:

init = tf.constant(key_to_numpy[ name.split('siameseN/')[1] ])
var = tf.get_variable(name,  initializer=init)
#var = tf.get_variable(name, shape, initializer=initializer)
return var

这是一个非常黑客的方法。我没有使用@edit的答案,因为我已经编写了上面的代码。此外,我的所有权重都是在一个函数中创建的,该函数将这些权重分配给变量var并返回它。因为这类似于函数式编程,变量var会被覆盖。 var永远不会暴露给更高级别的职能。要使用@ edit的答案,我必须为每次初始化使用不同的张量变量名称,并以某种方式将它们暴露给更高级别的函数,以便保护程序可以在它们的var_to_be_restored_to中使用它们答案。

但是@ edit的解决方案是不那么苛刻的解决方案,因为它遵循记录的用法。所以我接受了这个答案。我所做的可以是另一种解决方案。