我正在尝试通过检查点文件将正式的resnet_v2_50预训练变量加载到我的tensorflow网络中,该网络由resnet_v2_50-> my_custom_classifier组成。在培训期间,我使用的代码库使用tf.train.MonitoredTrainingSession
,它可以自动处理检查点的保存和加载。
但是,MonitoredTrainingSession似乎没有选择将检查点加载到变量的子集中,然后初始化其余的任何其他变量,我想这样做,以便在未创建先前的检查点时,我可以使用resnet权重初始化模型。
通过在定义train_op之前保存resnet_v2_50变量:
self.resnet_vars = slim.get_variables_to_restore(include=['resnet_v2_50'], exclude=['classify'])
train_op = tf.train.AdamOptimizer(lr).minimize(loss_xe + w_match * loss_l2u, colocate_gradients_with_ops=
True)
我能够隔离需要更新的变量。
有了这些变量,如果我只想评估模型,就可以恢复预训练的权重:
def eval_mode(self, ckpt=None):
¦ with self.graph.as_default():
¦ ¦ self.session = tf.Session(config=utils.get_config())
¦ ¦ self.step = tf.global_variables_initializer()
¦ ¦ saver = tf.train.Saver(self.resnet_vars)
¦ ¦ ckpt = os.path.abspath(ckpt)
¦ ¦ saver.restore(self.session, ckpt)
¦ ¦ self.tmp.step = 0
¦ ¦ self.session.run(self.step)
return self
但是,尝试还原MonitoredTrainingSession中的变量会导致错误
Graph is finalized and cannot be modified
我已经尝试使用上面eval_mode中使用的代码以及使用
ckpt = os.path.abspat(ckpt)
tf.train.init_from_checkpoint(ckpt, {v.name.split(':')[0]: v for v in self.tmp_vars})
立即将会话初始化为MonitoredTrainingSession之后。