我的网络的输出层是
model.add(Dense(2, activation=activations.softmax))
输出一个热编码类别预测。
model.predict
因此返回n个预测,例如
[9.9584144e-01, 4.1585001e-03],
[7.5779420e-01, 2.4220583e-01],
...
我现在不仅想要sklearn提供的完整的混淆矩阵
metrics.confusion_matrix(y_TEST.argmax(axis=1), y_pred.argmax(axis=1), normalize='pred')
但是我想看看仅在最大预测值大于给定阈值的情况下,混淆矩阵才显示出来。
类似
metrics.confusion_matrix(y_TEST.argmax(axis=1), y_pred.argmax(axis=1), normalize='pred',
min_confidence_threshold='0.9')
sklearn或此处的任何标准工具是否提供类似的功能?
如果没有,我如何根据其中一个数组的条件过滤两个数组(y_TEST,y_pred)?
答案 0 :(得分:0)
以下是可以执行此操作的函数:
def threshold_matrix(labels,probs,threshold=0):
assert 0 <= threshold <= 1
t_map = lambda x : 1 if tf.math.reduce_max(x) > threshold else 0
thres = tf.map_fn(t_map, probs)
pred = tf.boolean_mask(tf.argmax(probs,1),thres)
labs = tf.boolean_mask(tf.cast(labels, tf.int64),thres)
return tf.math.confusion_matrix(labs,pred)