我的CNN产生了以下内容(来自model.predict()
):
Tensor("input_1:0", shape=(?, 2, 26, 1), dtype=float32)
[9.9952221e-01 2.3613637e-04 1.9953270e-06 1.6922619e-05 2.2012556e-04
2.4441533e-07 3.5276526e-07 7.4913805e-07 4.0657511e-07 8.7760031e-07]
我想从这个numpy数组中获取最大值的索引。现在,我已经尝试过这样做(x
是上面的数组):
result = x.index(max(x))
相反,这会引发一个错误,指出此数据类型不支持.index
?
答案 0 :(得分:0)
您可以简单地使用np.argmax
函数:
import numpy as np
preds = model.predict(test_data)
pred_class = np.argmax(preds, axis=-1)