Tensorflow传输学习如何从一个检查点文件加载部分图层,并将所有图形变量保存在另一个检查点文件中

时间:2017-08-10 14:50:23

标签: tensorflow

在机器学习中使用tensorflow时,有时我们想要预先训练模型,并通过检查点文件(使用saver.save())将模型保存在pre-train_model floder中。然后我们希望使用预列车模型层的一部分来初始化新网络,并使用检查点文件将新列车模型保存在new-train_model文件夹中。

那么,我该怎么做才能实现这个功能。

1 个答案:

答案 0 :(得分:2)

Tensorflow: 当我们希望saver在我们的火车图中保存所有变量时,通常我们按如下方式定义保护程序:

saver = tf.train.Saver()

在init函数中,

__init__(self,
         vat_list=None
        ......
        )
self._var_list = var_list
.......
if self._var_list in None:
    self._var_list = variables._all_saveable_objects()
......

如果我们想从其他cnn模型加载2个图层,那么我们可以定义要恢复的变量列表并将其提供给保护程序对象,如下所示:

variables_to_restore = [var for var in tf.global_variables()
                        if var.name.startswith('conv_1')
                        or var.name.startswith('conv_2')]
saver = tf.train.Saver(variables_to_restore)
........
saver.restore(....)

这是一个例子,您可以根据自己的需要进行更改。

但是,在使用saver.save()将新模型保存到新模型文件夹后,如果要在下次恢复所有图形变量,则可以使用以下代码初始化保护对象: / p>

saver = tf.train.Saver()

之后,当您执行saver.restore()时,您很可能会看到“tensor node或vars无法找到”类型的错误。

要解决这些错误,您可以在转移模型时执行以下操作:

variables_to_restore = [var for var in tf.global_variables()
                        if var.name.startswith('conv_1')
                        or var.name.startswith('conv_2')]
saver = tf.train.Saver(variables_to_restore)
saver_all = tf.train.Saver()
........
saver.restore(....)
saver = saver_all
......
saver.save(<in new model folder>)