使用现有的Tensorflow模型进行预测时出现问题

时间:2020-03-26 19:39:07

标签: python tensorflow machine-learning keras

我在一个数据集上训练了一个模型,该数据集包含属于两个不同类别的图像,现在正尝试从该模型对新图像进行一些预测。我使用saved_model格式进行保存,并试图在模型上加载和预测一张图像。我的代码如下

loaded = tf.keras.models.load_model('/Library/...')

loaded.compile(loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),
               optimizer=tf.keras.optimizers.SGD(lr=0.005, momentum=0.9), metrics=['accuracy'])

test_image = image.load_img(img_path, target_size=(img_width, img_height))
test_image = image.img_to_array(test_image)
test_image = np.expand_dims(test_image, axis=0)
test_image = test_image.reshape(img_width, img_height)
result = loaded.predict(test_image)
print(loaded.predict(test_image))
print(result)

我得到了错误:

Traceback (most recent call last):
  File "/Users/...", line 73, in <module>
    test_image = test_image.reshape(img_width, img_height)
ValueError: cannot reshape array of size 268203 into shape (299,299)

我认为这是图像文件的问题,但是它与我以前训练的图像来自同一来源,并且在那里我没有遇到任何问题。所有文件都是RGB png图像(我认为问题是它们是RGBA,但事实并非如此)。任何帮助将不胜感激!

1 个答案:

答案 0 :(得分:1)

您必须检查图像尺寸。数组中元素的数量必须是.reshape参数的乘积。 299 * 299不等于268203。

一个例子是

a = np.arange(6)

这些是有效的重塑:

a.reshape(1,6)
a.reshape(2,3)
a.reshape(3,2)
a.reshape(6,1)

因为参数的乘积为6,即数组的长度。