混淆矩阵仅考虑预测>阈值

时间:2020-05-30 11:39:09

标签: python machine-learning scikit-learn

我的网络的输出层是

model.add(Dense(2, activation=activations.softmax))

输出一个热编码类别预测。

model.predict因此返回n个预测,例如

[9.9584144e-01, 4.1585001e-03],
[7.5779420e-01, 2.4220583e-01],
...

我现在不仅想要sklearn提供的完整的混淆矩阵

metrics.confusion_matrix(y_TEST.argmax(axis=1), y_pred.argmax(axis=1), normalize='pred')

但是我想看看仅在最大预测值大于给定阈值的情况下,混淆矩阵才显示出来。

类似

metrics.confusion_matrix(y_TEST.argmax(axis=1), y_pred.argmax(axis=1), normalize='pred', 
                         min_confidence_threshold='0.9')

sklearn或此处的任何标准工具是否提供类似的功能?

如果没有,我如何根据其中一个数组的条件过滤两个数组(y_TEST,y_pred)?

1 个答案:

答案 0 :(得分:0)

以下是可以执行此操作的函数:

def threshold_matrix(labels,probs,threshold=0):
    assert 0 <= threshold <= 1
    t_map = lambda x : 1 if tf.math.reduce_max(x) > threshold else 0
    thres = tf.map_fn(t_map, probs)

    pred  = tf.boolean_mask(tf.argmax(probs,1),thres)
    labs  = tf.boolean_mask(tf.cast(labels, tf.int64),thres)

    return tf.math.confusion_matrix(labs,pred)