Tensearn的Tensorflow argmax()?

时间:2017-05-11 15:42:30

标签: python mnist tflearn

我使用tflearn mnist 数据集来预测手写数字。

一切正常,但我的标签为 one_hot tflearn中是否有一个函数,它与Tensorflow中的argmax()相同?

1 个答案:

答案 0 :(得分:0)

你可以这样做:

pred = model.predict(test_data)

print([ np.where(r==1)[0][0] for r in np.round(pred) ])

最佳。