如何在MonitoredTrainingSession中初始化变量子集以在resnet_50中加载预训练的权重

时间:2019-05-27 06:19:17

标签: tensorflow

我正在尝试通过检查点文件将正式的resnet_v2_50预训练变量加载到我的tensorflow网络中,该网络由resnet_v2_50-> my_custom_classifier组成。在培训期间,我使用的代码库使用tf.train.MonitoredTrainingSession,它可以自动处理检查点的保存和加载。

但是,MonitoredTrainingSession似乎没有选择将检查点加载到变量的子集中,然后初始化其余的任何其他变量,我想这样做,以便在未创建先前的检查点时,我可以使用resnet权重初始化模型。

通过在定义train_op之前保存resnet_v2_50变量:

self.resnet_vars = slim.get_variables_to_restore(include=['resnet_v2_50'], exclude=['classify'])            
train_op = tf.train.AdamOptimizer(lr).minimize(loss_xe + w_match * loss_l2u, colocate_gradients_with_ops=
              True)  

我能够隔离需要更新的变量。

有了这些变量,如果我只想评估模型,就可以恢复预训练的权重:

def eval_mode(self, ckpt=None):                                                                              
     ¦   with self.graph.as_default():                                                                            
     ¦   ¦   self.session = tf.Session(config=utils.get_config())
     ¦   ¦   self.step = tf.global_variables_initializer()
     ¦   ¦   saver = tf.train.Saver(self.resnet_vars)
     ¦   ¦   ckpt = os.path.abspath(ckpt)      
     ¦   ¦   saver.restore(self.session, ckpt)   
     ¦   ¦   self.tmp.step = 0                                                                                    
     ¦   ¦   self.session.run(self.step)                    
     return self 

但是,尝试还原MonitoredTrainingSession中的变量会导致错误

Graph is finalized and cannot be modified

我已经尝试使用上面eval_mode中使用的代码以及使用

ckpt = os.path.abspat(ckpt)
tf.train.init_from_checkpoint(ckpt, {v.name.split(':')[0]: v for v in self.tmp_vars}) 

立即将会话初始化为MonitoredTrainingSession之后。

0 个答案:

没有答案