How do I correctly resume network training from a TensorFlow checkpoint file?

Asked: 2018-09-02 11:39:41

Tags: python tensorflow deep-learning restore checkpoint

I have been struggling for a whole day to restore a model, without success. My code consists of a class TF_MLPRegressor(), where I define the network architecture in the constructor. Then I call a fit() function for training. This is how I save a simple Perceptron model with one hidden layer inside the fit() function:

            starting_epoch = 0
            # Launch the graph
            tf.set_random_seed(self.random_state)   # fix the random seed before creating the Session in order to take effect!
            if hasattr(self, 'sess'):
                self.sess.close()
                del self.sess   # delete Session to release memory
                gc.collect()
            self.sess = tf.Session(config=self.config) # save the session to predict from new data
            # Create a saver object which will save all the variables
            saver = tf.train.Saver(max_to_keep=2)  # max_to_keep=2 means to not keep more than 2 checkpoint files
            self.sess.run(tf.global_variables_initializer())

# ... (every 100 epochs)

            saver.save(self.sess, self.checkpoint_dir+"/resume", global_step=epoch)
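
For context, a minimal sketch (using a hypothetical path, not output from a real run) of how `global_step` shapes the checkpoint prefix and how an epoch number can be recovered from it later:

```python
# saver.save(sess, prefix, global_step=epoch) appends "-<epoch>" to the
# prefix, so the latest checkpoint path looks like "checkpoints/resume-300".
ckpt_path = "checkpoints/resume-300"  # hypothetical path, for illustration only

# The epoch can be recovered by splitting on the last dash:
starting_epoch = int(ckpt_path.split('-')[-1])
print(starting_epoch)  # 300
```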

Then I create a new TF_MLPRegressor() instance with exactly the same input parameter values and call the fit() function to restore the model, like this:

    self.sess = tf.Session(config=self.config)  # create a new session to load saved variables
    ckpt = tf.train.latest_checkpoint(self.checkpoint_dir)
    starting_epoch = int(ckpt.split('-')[-1])
    metagraph = ".".join([ckpt, 'meta'])
    saver = tf.train.import_meta_graph(metagraph)
    self.sess.run(tf.global_variables_initializer())    # Initialize variables
    lhl = tf.trainable_variables()[2]
    lhlA = lhl.eval(session=self.sess)
    saver.restore(sess=self.sess, save_path=ckpt)   # Restore model weights from previously saved model
    lhlB = lhl.eval(session=self.sess)
    print lhlA == lhlB

lhlA and lhlB are the last hidden layer weights before and after restoring, and according to my code they match exactly, i.e. the saved model is not being loaded into the session. What am I doing wrong?
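As an aside (assuming the weights come back from `eval()` as NumPy arrays, which is standard TF behavior): `lhlA == lhlB` compares elementwise and prints a boolean array rather than a single bool, so `np.array_equal` is the unambiguous way to test whether two weight snapshots are really identical:

```python
import numpy as np

a = np.array([[1.0, 2.0], [3.0, 4.0]])  # stand-ins for lhlA / lhlB
b = a.copy()

elementwise = (a == b)            # boolean array, one entry per weight
identical = np.array_equal(a, b)  # single bool: shapes and values all match
print(identical)  # True
```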

1 Answer:

Answer 0 (score: 1)

I found a workaround! Strangely, the meta graph did not contain all of the variables that I had defined, or at least not under the new names I had assigned to them. For example, in the constructor I define the tensors that will carry the input feature vectors and the experimental values:

self.x = tf.placeholder("float", [None, feat_num], name='x')
self.y = tf.placeholder("float", [None], name='y')

However, when I execute tf.reset_default_graph() and load the meta graph, I get the following variable list:

[
<tf.Variable 'Variable:0' shape=(300, 300) dtype=float32_ref>, 
<tf.Variable 'Variable_1:0' shape=(300,) dtype=float32_ref>, 
<tf.Variable 'Variable_2:0' shape=(300, 1) dtype=float32_ref>, 
<tf.Variable 'Variable_3:0' shape=(1,) dtype=float32_ref>
]

For the record, each input feature vector consists of 300 features. Anyway, when I later try to start training with:

_, c, p = self.sess.run([self.optimizer, self.cost, self.pred],
                        feed_dict={self.x: batch_x, self.y: batch_y, self.isTrain: True})

I get an error like this:

"TypeError: Cannot interpret feed_dict key as Tensor: Tensor 'x' is not an element of this graph."

So, since I define the network architecture in the constructor every time I create an instance of class TF_MLPRegressor(), I decided not to load the meta graph at all, and it worked! I don't know why TF fails to save all the variables into the meta graph; maybe it is because I define the network architecture explicitly (I don't use wrappers or default layers), unlike the following example:

https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/4_Utils/save_restore_model.py

To sum up, I save models as described in my first message, but to restore them I use the following:

# NOTE: this assumes the network architecture has already been rebuilt
# (here it happens in the TF_MLPRegressor() constructor), so the default
# graph contains the same variables that were saved in the checkpoint.
saver = tf.train.Saver(max_to_keep=2)
self.sess = tf.Session(config=self.config)  # create a new session to load saved variables
self.sess.run(tf.global_variables_initializer())
ckpt = tf.train.latest_checkpoint(self.checkpoint_dir)
saver.restore(sess=self.sess, save_path=ckpt)   # Restore model weights from previously saved model