在机器学习中使用tensorflow时,有时我们想要预先训练模型,并通过检查点文件(使用saver.save())将模型保存在pre-train_model floder中。然后我们希望使用预列车模型层的一部分来初始化新网络,并使用检查点文件将新列车模型保存在new-train_model文件夹中。
那么,我该怎么做才能实现这个功能。
答案 0 :(得分:2)
Tensorflow: 当我们希望saver在我们的火车图中保存所有变量时,通常我们按如下方式定义保护程序:
saver = tf.train.Saver()
在init函数中,
__init__(self,
vat_list=None
......
)
self._var_list = var_list
.......
if self._var_list in None:
self._var_list = variables._all_saveable_objects()
......
如果我们想从其他cnn模型加载2个图层,那么我们可以定义要恢复的变量列表并将其提供给保护程序对象,如下所示:
variables_to_restore = [var for var in tf.global_variables()
if var.name.startswith('conv_1')
or var.name.startswith('conv_2')]
saver = tf.train.Saver(variables_to_restore)
........
saver.restore(....)
这是一个例子,您可以根据自己的需要进行更改。
但是,在使用saver.save()将新模型保存到新模型文件夹后,如果要在下次恢复所有图形变量,则可以使用以下代码初始化保护对象: / p>
saver = tf.train.Saver()
之后,当您执行saver.restore()时,您很可能会看到“tensor node或vars无法找到”类型的错误。
要解决这些错误,您可以在转移模型时执行以下操作:
variables_to_restore = [var for var in tf.global_variables()
if var.name.startswith('conv_1')
or var.name.startswith('conv_2')]
saver = tf.train.Saver(variables_to_restore)
saver_all = tf.train.Saver()
........
saver.restore(....)
saver = saver_all
......
saver.save(<in new model folder>)