Tensorflow / CNN /测试一张图像

时间:2017-08-24 14:55:03

标签: image testing tensorflow prediction

我有一个简单的请求,但结果似乎很难达到。

我在Tensorflow中做了一个简单的模型,基于分为3个标签的图像(非常简单)。

数据输入后,模型经过培训+测试达到60%的良好结果。

然后,我想使用现有模型测试单个图像,看看模型是否猜出了正确的标签。

在整个测试过程中,整个过程非常简单:

#%%
import matplotlib.image as mpimg
import matplotlib.pyplot as plt

test_image = mpimg.imread(os.path.join("images","cnn","cnn_test.png"))[:, :, :channels]
plt.imshow(test_image)
plt.axis("off")
plt.show()


X_test = test_image.reshape(-1,height, width, channels)

# initialize the variables
#sess.run(tf.global_variables_initializer())

X_test = test_image.reshape(-1, height, width, channels)

tf.reset_default_graph()


# x is the input array, which will contain the data from an image 

saver = tf.train.import_meta_graph('./02_CNN/data/my_model8.ckpt.meta')

print('meta graph imported')


X_test = test_image.reshape(-1, height, width, channels)

sess = tf.Session()
saver.restore(sess, './02_CNN/data/my_model8.ckpt')
print('model graph restored')

feed_dict = {x: X_test}

prediction = sess.run(feed_dict)


max_index = np.argmax(prediction)

print(max_index)

我有以下结果:

TypeError: Fetch argument array([[[[ 0.03137255,  0.20392157,  0.36862746],
         [ 0.02745098,  0.20784314,  0.37254903],
         [ 0.02352941,  0.21176471,  0.3764706 ],
         ..., 
      ....
         [ 0.08627451,  0.08235294,  0.09019608],
         [ 0.12941177,  0.12156863,  0.12941177],
         [ 0.13725491,  0.1254902 ,  0.13725491]]]], dtype=float32) 
has invalid type <class 'numpy.ndarray'>, must be a string or Tensor. 
(Can not convert a ndarray into a Tensor or Operation.)

有关如何解决此问题的任何想法?

非常感谢,

尼古拉斯

0 个答案:

没有答案