当前能够获得正确的预测,但是打印的类别字符串错误。
如果只有2个类别,则此代码将完美运行,但是现在我使用3个类别。
CATEGORIES = ["RGB images score 1", "RGB images score 2", "RGB images score 3"]
prediction = model.predict([prepare('score3.png')])
print(prediction[0]) # will be a list in a list.
print(CATEGORIES[int(prediction[0][1])])
输出
[0. 0. 1.]
RGB images score 1
实际输出应为“ RGB图像得分3”。但是,我却获得了“ RGB图像得分1”。只有3张图片有此问题。
答案 0 :(得分:0)
设法使用if else来解决它
CATEGORIES = ["RGB images score 1", "RGB images score 2", "RGB images score 3"]
prediction = model.predict([prepare('score3.png')])
if prediction[0][2] == 1.0:
print("RGB images score 3")
else:
print(CATEGORIES[int(prediction[0][1])])
答案 1 :(得分:0)
您正在看的prediction[0][1]
有点像:网络是否预测了2
。
它确实适用于2个类别,但不适用于更多类别!您需要找到预测[0]等于1的索引。
您可以使用例如print(CATEGORIES[int(np.argmax(prediction[0]))])