尝试使用Google here提供的体系结构和检查点运行Inceptionv3 Tensorflow模型。
我的问题是我的脚本在saver.restore(sess, "./inception_v3.ckpt")
崩溃时出现以下错误:
tensorflow.python.framework.errors.NotFoundError: Tensor name "InceptionV3/Mixed_5b/Branch_1/Conv2d_0b_5x5/biases" not found in checkpoint files ./inception_v3.ckpt
这是我的代码:
import tensorflow as tf
import inception_v3
with tf.Session() as sess:
image = tf.read_file('./file.jpg')
# code to decode, crop, convert jpeg
eval_inputs = tf.pack([image])
logits, _ = inception_v3.inception_v3(eval_inputs, num_classes=1001, is_training=False)
sess.run(tf.initialize_all_variables())
saver = tf.train.Saver()
saver.restore(sess, "./inception_v3.ckpt")
我在其他检查点/模型组合中遇到相同的错误,所以这必须是我的代码的问题。不知道我做错了什么。
谢谢
答案 0 :(得分:0)
确实,检查点文件不包含此张量。你能在github上提交一个bug吗?
答案 1 :(得分:0)
您需要在inception_v3()
返回的arg_scope()
内拨打inception_v3_arg_scope()
,如下所示:
import tensorflow as tf
import tensorflow.contrib.slim as slim
from nets.inception_v3 import inception_v3, inception_v3_arg_scope
height = 299
width = 299
channels = 3
# Create graph
X = tf.placeholder(tf.float32, shape=[None, height, width, channels])
with slim.arg_scope(inception_v3_arg_scope()):
logits, end_points = inception_v3(X, num_classes=1001,
is_training=False)
predictions = end_points["Predictions"]
saver = tf.train.Saver()
X_test = ... # your images, shape [batch_size, 299, 299, 3]
# Execute graph
with tf.Session() as sess:
saver.restore(sess, "./inception_v3.ckpt")
predictions_val = predictions.eval(feed_dict={X: X_test})
predicted_classes = np.argmax(predictions_val, axis=1)
我建议明确区分构造阶段和执行阶段。刚刚在网络上的随机照片上进行了测试,它运行良好。 :)