无法恢复张量流变量

时间:2019-07-24 15:39:58

标签: python tensorflow deep-learning

我有一个如下类,load函数返回张量流保存的图。

class StoredGraph():
     .
     .
     .

     def build_meta_saver(self, meta_file=None):
         meta_file = self._get_latest_checkpoint() + '.meta' if not meta_file else meta_file
         meta_saver = tf.train.import_meta_graph(meta_file)
         return meta_saver

    def load(self, sess, saverObj):
        saverObj.restore(sess, self._get_latest_checkpoint())
        graph = tf.get_default_graph()
        return graph

我还有另一个课程,叫它TrainNet()

class TrainNet():
    .
    .
    .
  def train(dataset):

    self.train_graph = tf.Graph()
    meta_saver, saver = None, None
    GraphIO = StoredGraph(experiment_dir)
    latest_checkpoint = GraphIO._get_latest_checkpoint()

    with self.train_graph.as_default():
        tf.set_random_seed(42)

        if not latest_checkpoint:
            #build graph here
             self.build_graph()
        else:
            meta_saver = GraphIO.build_meta_saver()  # this loads the meta file


        with tf.Session(graph=self.train_graph) as sess:
            if not meta_saver:
                sess.run(tf.global_variables_initializer())


            if latest_checkpoint:
                self.scaler, self.train_graph = GraphIO.load(sess, meta_saver)

            #here access placeholders using self.train_graph.get_tensor_by_name()... 
            #and feed the values

在我的培训课程中,我仅通过使用load函数作为self.train_graph = StoredGraphclass.load(sess,metasaver)加载图形来使用以上课程

我的疑问是通过加载保存的图形是否还原了所有变量?通常,每个人都在与saver.restore()相同的脚本中定义恢复操作,该脚本可恢复图的所有变量。但是我在另一个类中调用saver.restore(),并使用返回的图形访问占位符。

我认为并不是所有的变量都可以恢复。以上方法错误吗?当我检查以不同的训练步骤编写的两个不同的.meta文件中的权重值时,就产生了这种怀疑,并且权重值完全相同,这意味着此变量未更新或恢复方法存在某些故障。

1 个答案:

答案 0 :(得分:0)

只要您在文件中创建了所有必需的变量,并为其赋予了相同“名称”(当然形状也必须正确),restore会将所有适当的值加载到适当的变量中。 Here,您可以找到一个玩具示例,向您展示如何完成此操作。