从张量流检查点

时间:2018-05-11 18:07:44

标签: python tensorflow

我有张量流代码,我可以在其中保存并加载神经网络的模型。

def save(self, checkpoint_dir, step):
    model_name = "cyclegan.model"
    model_dir = "%s_%s" % (self.dataset_dir, self.image_size)
    checkpoint_dir = os.path.join(checkpoint_dir, model_dir)

    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    self.saver.save(self.sess,
                    os.path.join(checkpoint_dir, model_name),
                    global_step=step)

def load(self, checkpoint_dir):
    print(" [*] Reading checkpoint...")

    model_dir = "%s_%s" % (self.dataset_dir, self.image_size)
    checkpoint_dir = os.path.join(checkpoint_dir, model_dir)

    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        # ckpt_name = 'cyclegan.model-2052002'
        self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
        return True
    else:
        return False

我有一个检查点目录,其中包含以下文件:

checkpoint
cyclegan.model-2052002.data-00000-of-00001
cyclegan.model-2052002.index
cyclegan.model-2052002.meta
cyclegan.model-2053002.data-00000-of-00001
cyclegan.model-2053002.index
cyclegan.model-2053002.meta
cyclegan.model-2054002.data-00000-of-00001
cyclegan.model-2054002.index
cyclegan.model-2054002.meta
cyclegan.model-2055002.data-00000-of-00001
cyclegan.model-2055002.index
cyclegan.model-2055002.meta
cyclegan.model-2056002.data-00000-of-00001
cyclegan.model-2056002.index
cyclegan.model-2056002.meta 

如果我调用load函数,我猜它会加载最新的模型。请告诉我是否正确。我想要做的是从上面的列表中加载一个特定的模型。那么取消注释我为ckpt_name赋值的行是否有效?或者它仍然会加载最新型号。请帮忙。

1 个答案:

答案 0 :(得分:0)

取消注释该行应该加载引用的特定检查点。您可以参考tf.train.Saver.restore()