如何打印出正确的预测类别?

时间:2019-09-26 07:17:44

标签: python python-3.x list tensorflow

当前能够获得正确的预测,但是打印的类别字符串错误。

如果只有2个类别,则此代码将完美运行,但是现在我使用3个类别。

CATEGORIES = ["RGB images score 1", "RGB images score 2", "RGB images score 3"]

prediction = model.predict([prepare('score3.png')])

print(prediction[0])  # will be a list in a list.

print(CATEGORIES[int(prediction[0][1])])

输出

[0. 0. 1.]

RGB images score 1

实际输出应为“ RGB图像得分3”。但是,我却获得了“ RGB图像得分1”。只有3张图片有此问题。

2 个答案:

答案 0 :(得分:0)

设法使用if else来解决它

CATEGORIES = ["RGB images score 1", "RGB images score 2", "RGB images score 3"]

prediction = model.predict([prepare('score3.png')])

if prediction[0][2] == 1.0:
    print("RGB images score 3")

else:
    print(CATEGORIES[int(prediction[0][1])])

答案 1 :(得分:0)

您正在看的prediction[0][1]有点像:网络是否预测了2

它确实适用于2个类别,但不适用于更多类别!您需要找到预测[0]等于1的索引。

您可以使用例如print(CATEGORIES[int(np.argmax(prediction[0]))])