tensorflow - 无法恢复模型 - “无法匹配检查点的文件”

时间:2017-10-09 17:53:48

标签: tensorflow save restore

这是我的模型保存到磁盘:

import tensorflow as tf
import numpy as np


BATCH_SIZE = 3
VECTOR_SIZE = 1
LEARNING_RATE = 0.1

x = tf.placeholder(tf.float32, [BATCH_SIZE, VECTOR_SIZE],
                   name='input_placeholder')
y = tf.placeholder(tf.float32, [BATCH_SIZE, VECTOR_SIZE],
                   name='labels_placeholder')

W = tf.get_variable('W', [VECTOR_SIZE, BATCH_SIZE])
b = tf.get_variable('b', [VECTOR_SIZE], initializer=tf.constant_initializer(0.0))

y_hat = tf.matmul(W, x) + b
predict = tf.add(tf.matmul(W, x), b, name='predict')
total_loss = tf.reduce_mean(y-y_hat)
train_step = tf.train.AdagradOptimizer(LEARNING_RATE).minimize(total_loss)
X = np.ones([BATCH_SIZE, VECTOR_SIZE])
Y = np.ones([BATCH_SIZE, VECTOR_SIZE])
all_saver = tf.train.Saver() 

sess= tf.Session()
sess.run(tf.global_variables_initializer())
sess.run([train_step], feed_dict = {x: X, y:Y})
save_path = r'C:\tmp\tmp\\'
all_saver.save(sess,save_path)

尝试恢复时

checkpoint_path = r'C:\tmp\tmp\\'
tf.train.latest_checkpoint(checkpoint_path)

我收到以下错误消息:

ERROR:tensorflow:Couldn't match files for checkpoint C:\tmp\tmp\\

C:\tmp\tmp\我有以下文件:

.data-00000-of-00001
.index
.meta
checkpoint

有什么想法吗?

3 个答案:

答案 0 :(得分:3)

来自saver.save tensorflow api:

  

save_path:String。检查点文件名的路径。如果保存程序是分片的,则这是分片检查点文件名的前缀。

save_path中,您没有指定检查点文件名。

为了将来使用,请尝试设置: checkpoint_path = r'C:\tmp\tmp\my-model'

如果要加载以前保存的模型,请执行以下操作:

  1. 为这些文件添加字符串my-model
  2. .data-00000-of-00001
    .index
    .meta
    
    1. 修改checkpoint文件,使其指向您的检查点:
    2. model_checkpoint_path: "C:\tmp\tmp\my-model"
      all_model_checkpoint_paths: "C:\tmp\tmp\my-model"
      

      现在应该可以加载检查点。

答案 1 :(得分:2)

文件刚刚命名为行吗?从点开始?

如果是这种情况,您应该考虑以不同方式保存它们,因为这可能是问题所在。

尝试:

NUMBER_OF_CKPT = 60 saver.save(sess,save_path,global_step=NUMBER_OF_CKPT)

通常做的是将global_step保存为ckpt的编号。

希望能解决它!

答案 2 :(得分:1)

FWIW我在AI平台(Cloud ML Engine)上训练自定义估算器时看到此错误。对我来说,这个问题是由我保存检查点/模型元数据的GCS存储桶中的region引起的。

当此存储桶的region设置为us (multiple regions in United States)时,我在评估期间看到此错误。将GCS存储桶的region设置为运行AI Platform作业的region(在我的情况下为us-central1 (Iowa))可以解决此问题。