在Tensorflow中不同变量范围下保存/恢复权重

时间:2019-03-19 16:30:58

标签: python tensorflow

一段时间以来,我一直在尝试研究减轻体重的模型,但我仍然无法完全掌握它。我觉得我想做的事情应该很简单,但是我没有找到解决方案。

最终目标是通过一系列预训练的网络来进行传输标记。我将模型/图层写为类,因此用于保存权重和还原的类方法将是理想的。

示例: 如果我有一个图,功能> A> B>标签,其中A和B是子网,我想保存和/或恢复这些部分的权重。假设我已经掌握了A的权重,但是可变范围现在有所不同,如何从其他培训课程中恢复我为A训练的权重?训练完这个新图之后,我想要为我的新A权重分配1个目录,为我的新B权重分配1个目录,为完整图提供1个目录(我可以处理完整图位)。

很有可能我会一直忽略该解决方案,但是模型保存的文档却很少。

希望我已经很好地说明了这种情况。

1 个答案:

答案 0 :(得分:0)

您可以使用tf.train.init_from_checkpoint

定义模型

def model_fn():
    with tf.variable_scope('One'):
        layer = any_tf_layer
    with tf.variable_scope('Two'):
        layer = any_tf_layer 

在检查点文件中输出变量名称

vars = [i[0] for i in tf.train.list_variables(ckpt_file)]

然后,您可以创建分配图以仅加载模型中定义的变量。 您还可以为恢复的变量分配新名称

 map = {variable.op.name: variable for variable in tf.global_variables() if variable.op.name in vars}

此行位于Estimator API的会话或外部模型函数之前

tf.train.init_from_checkpoint(ckpt_file, map)

https://www.tensorflow.org/api_docs/python/tf/train/init_from_checkpoint

您也可以使用tf.train.Saver 首先,您需要知道变量的名称

vars_dict = {}
for var_current in tf.global_variables():
    print(var_current)
    print(var_current.op.name) # this gets only name

for var_ckpt in tf.train.list_variables(ckpt):
    print(var_ckpt[0]) this gets only name

当您知道所有变量的确切名称时,只要变量具有相同的形状和dtype,就可以分配所需的任何值。因此要获得字典

vars_dict[var_ckpt[0]) = tf.get_variable(var_current.op.name, shape) # remember to specify shape, you can always get it from var_current 

saver = tf.train.Saver(vars_dict)

看看我对类似问题的其他答案 How to restore pretrained checkpoint for current model in Tensorflow?