我有一个如下类,load
函数返回张量流保存的图。
class StoredGraph():
.
.
.
def build_meta_saver(self, meta_file=None):
meta_file = self._get_latest_checkpoint() + '.meta' if not meta_file else meta_file
meta_saver = tf.train.import_meta_graph(meta_file)
return meta_saver
def load(self, sess, saverObj):
saverObj.restore(sess, self._get_latest_checkpoint())
graph = tf.get_default_graph()
return graph
我还有另一个课程,叫它TrainNet()
。
class TrainNet():
.
.
.
def train(dataset):
self.train_graph = tf.Graph()
meta_saver, saver = None, None
GraphIO = StoredGraph(experiment_dir)
latest_checkpoint = GraphIO._get_latest_checkpoint()
with self.train_graph.as_default():
tf.set_random_seed(42)
if not latest_checkpoint:
#build graph here
self.build_graph()
else:
meta_saver = GraphIO.build_meta_saver() # this loads the meta file
with tf.Session(graph=self.train_graph) as sess:
if not meta_saver:
sess.run(tf.global_variables_initializer())
if latest_checkpoint:
self.scaler, self.train_graph = GraphIO.load(sess, meta_saver)
#here access placeholders using self.train_graph.get_tensor_by_name()...
#and feed the values
在我的培训课程中,我仅通过使用load
函数作为self.train_graph = StoredGraphclass.load(sess,metasaver)
加载图形来使用以上课程
我的疑问是通过加载保存的图形是否还原了所有变量?通常,每个人都在与saver.restore()
相同的脚本中定义恢复操作,该脚本可恢复图的所有变量。但是我在另一个类中调用saver.restore()
,并使用返回的图形访问占位符。
我认为并不是所有的变量都可以恢复。以上方法错误吗?当我检查以不同的训练步骤编写的两个不同的.meta
文件中的权重值时,就产生了这种怀疑,并且权重值完全相同,这意味着此变量未更新或恢复方法存在某些故障。
答案 0 :(得分:0)
只要您在文件中创建了所有必需的变量,并为其赋予了相同“名称”(当然形状也必须正确),restore
会将所有适当的值加载到适当的变量中。 Here,您可以找到一个玩具示例,向您展示如何完成此操作。