我有张量流代码,我可以在其中保存并加载神经网络的模型。
def save(self, checkpoint_dir, step):
model_name = "cyclegan.model"
model_dir = "%s_%s" % (self.dataset_dir, self.image_size)
checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)
self.saver.save(self.sess,
os.path.join(checkpoint_dir, model_name),
global_step=step)
def load(self, checkpoint_dir):
print(" [*] Reading checkpoint...")
model_dir = "%s_%s" % (self.dataset_dir, self.image_size)
checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
# ckpt_name = 'cyclegan.model-2052002'
self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
return True
else:
return False
我有一个检查点目录,其中包含以下文件:
checkpoint
cyclegan.model-2052002.data-00000-of-00001
cyclegan.model-2052002.index
cyclegan.model-2052002.meta
cyclegan.model-2053002.data-00000-of-00001
cyclegan.model-2053002.index
cyclegan.model-2053002.meta
cyclegan.model-2054002.data-00000-of-00001
cyclegan.model-2054002.index
cyclegan.model-2054002.meta
cyclegan.model-2055002.data-00000-of-00001
cyclegan.model-2055002.index
cyclegan.model-2055002.meta
cyclegan.model-2056002.data-00000-of-00001
cyclegan.model-2056002.index
cyclegan.model-2056002.meta
如果我调用load函数,我猜它会加载最新的模型。请告诉我是否正确。我想要做的是从上面的列表中加载一个特定的模型。那么取消注释我为ckpt_name赋值的行是否有效?或者它仍然会加载最新型号。请帮忙。