我正在努力恢复模型一天,但没有成功。我的代码由class TF_MLPRegressor()
组成,我在其中定义了构造函数中的网络体系结构。然后,我调用fit()
函数进行训练。因此,这就是我在fit()
函数中保存带有1个隐藏层的简单Perceptron模型的方式:
starting_epoch = 0
# Launch the graph
tf.set_random_seed(self.random_state) # fix the random seed before creating the Session in order to take effect!
if hasattr(self, 'sess'):
self.sess.close()
del self.sess # delete Session to release memory
gc.collect()
self.sess = tf.Session(config=self.config) # save the session to predict from new data
# Create a saver object which will save all the variables
saver = tf.train.Saver(max_to_keep=2) # max_to_keep=2 means to not keep more than 2 checkpoint files
self.sess.run(tf.global_variables_initializer())
# ... (each 100 epochs)
saver.save(self.sess, self.checkpoint_dir+"/resume", global_step=epoch)
然后,我使用完全相同的输入参数值创建一个新的TF_MLPRegressor()
实例,并调用fit()
函数来恢复模型,如下所示:
self.sess = tf.Session(config=self.config) # create a new session to load saved variables
ckpt = tf.train.latest_checkpoint(self.checkpoint_dir)
starting_epoch = int(ckpt.split('-')[-1])
metagraph = ".".join([ckpt, 'meta'])
saver = tf.train.import_meta_graph(metagraph)
self.sess.run(tf.global_variables_initializer()) # Initialize variables
lhl = tf.trainable_variables()[2]
lhlA = lhl.eval(session=self.sess)
saver.restore(sess=self.sess, save_path=ckpt) # Restore model weights from previously saved model
lhlB = lhl.eval(session=self.sess)
print lhlA == lhlB
lhlA
和lhlB
是还原前后的最后一个隐藏层权重,根据我的代码,它们完全匹配,即已保存的模型不会加载到会话中。我在做什么错了?
答案 0 :(得分:1)
我找到了解决方法!奇怪的是,元图不包含我定义或为其分配的新名称的所有变量。例如,在构造函数中,我定义了张量,这些张量将携带输入特征向量和实验值:
self.x = tf.placeholder("float", [None, feat_num], name='x')
self.y = tf.placeholder("float", [None], name='y')
但是,当我执行tf.reset_default_graph()
并加载元图时,会得到以下变量列表:
[
<tf.Variable 'Variable:0' shape=(300, 300) dtype=float32_ref>,
<tf.Variable 'Variable_1:0' shape=(300,) dtype=float32_ref>,
<tf.Variable 'Variable_2:0' shape=(300, 1) dtype=float32_ref>,
<tf.Variable 'Variable_3:0' shape=(1,) dtype=float32_ref>
]
为进行记录,每个输入特征向量具有300个特征。无论如何,当我稍后尝试使用以下方法来开始训练时:
_, c, p = self.sess.run([self.optimizer, self.cost, self.pred],
feed_dict={self.x: batch_x, self.y: batch_y, self.isTrain: True})
我收到如下错误:
"TypeError: Cannot interpret feed_dict key as Tensor: Tensor 'x' is not an element of this graph."
因此,由于每次创建class TF_MLPRegressor()
的实例时,我都在构造函数中定义了网络体系结构,因此我决定不加载该元图并且它起作用了!我不知道为什么TF不能将所有变量都保存到元图中,可能是因为我像下面的示例一样明确定义了网络架构(我不使用包装器或默认层):
总而言之,我按照我的第一条消息中的描述保存模型,但要恢复它们,我使用以下方法:
saver = tf.train.Saver(max_to_keep=2)
self.sess = tf.Session(config=self.config) # create a new session to load saved variables
self.sess.run(tf.global_variables_initializer())
ckpt = tf.train.latest_checkpoint(self.checkpoint_dir)
saver.restore(sess=self.sess, save_path=ckpt) # Restore model weights from previously saved model