在Tensorflow中恢复检查点 - 未找到张量名称

时间:2016-09-06 20:37:36

标签: tensorflow

尝试使用Google here提供的体系结构和检查点运行Inceptionv3 Tensorflow模型。

我的问题是我的脚本在saver.restore(sess, "./inception_v3.ckpt")崩溃时出现以下错误:

tensorflow.python.framework.errors.NotFoundError: Tensor name "InceptionV3/Mixed_5b/Branch_1/Conv2d_0b_5x5/biases" not found in checkpoint files ./inception_v3.ckpt

这是我的代码:

import tensorflow as tf
import inception_v3

with tf.Session() as sess:
  image = tf.read_file('./file.jpg')
  # code to decode, crop, convert jpeg
  eval_inputs = tf.pack([image])
  logits, _ = inception_v3.inception_v3(eval_inputs, num_classes=1001, is_training=False)
  sess.run(tf.initialize_all_variables())

  saver = tf.train.Saver()
  saver.restore(sess, "./inception_v3.ckpt")

我在其他检查点/模型组合中遇到相同的错误,所以这必须是我的代码的问题。不知道我做错了什么。

谢谢

2 个答案:

答案 0 :(得分:0)

确实,检查点文件不包含此张量。你能在github上提交一个bug吗?

答案 1 :(得分:0)

您需要在inception_v3()返回的arg_scope()内拨打inception_v3_arg_scope(),如下所示:

import tensorflow as tf
import tensorflow.contrib.slim as slim
from nets.inception_v3 import inception_v3, inception_v3_arg_scope

height = 299
width = 299
channels = 3

# Create graph
X = tf.placeholder(tf.float32, shape=[None, height, width, channels])
with slim.arg_scope(inception_v3_arg_scope()):
    logits, end_points = inception_v3(X, num_classes=1001,
                                      is_training=False)
predictions = end_points["Predictions"]
saver = tf.train.Saver()

X_test = ... # your images, shape [batch_size, 299, 299, 3]

# Execute graph
with tf.Session() as sess:
    saver.restore(sess, "./inception_v3.ckpt")
    predictions_val = predictions.eval(feed_dict={X: X_test})

predicted_classes = np.argmax(predictions_val, axis=1)

我建议明确区分构造阶段和执行阶段。刚刚在网络上的随机照片上进行了测试,它运行良好。 :)