我使用最后一层的keras作为softmax,返回概率(0-1.0)浮点值。假设我有4个类,如何获得具有最高概率的类的索引?是否有keras,numpy或scikit-learn功能来执行此操作?
pred = model.predict(....)
# pred = [[0.9, 0.0, 0.0, 0.1], --------> [0, 1, 1]
# [0.1, 0.8, 0.1, 0.0], change to
# [0.1, 0.8, 0.1, 0.0]]
我想从0-1 float的数组更改为整数的原因是因为我想使用scikit-learn的混淆矩阵来显示精度,它只接受整数标签。