用于多标签分类的keras自定义指标

时间:2018-10-29 00:54:13

标签: tensorflow machine-learning keras deep-learning multilabel-classification

我正在使用sigmoidbinary_crossentropy进行多标签分类。

例如,y_true的标签就像[1,0,1,0,0],而y_pred的标签就像[0.8,0.3,0.9,0,0]

如何设置Keras自定义指标函数,以便将y_pred中大于0.5的每个元素映射为1,将y_pred中小于0.5的每个元素映射为0,然后比较数字y_pred中与y_true匹配的标签的数量?

1 个答案:

答案 0 :(得分:2)

由于您正在执行多标签分类,因此您似乎希望将整个真实标签和预测标签相互比较。例如,对于单个标签的真实标签为[1, 0, 0]并且预测标签为[0, 0, 0]的样本,您将预测准确性视为零(尽管已经预测了第二和第三类的标签正确)。在这种情况下,您可以比较标签,然后从后端使用all()方法来确保所有类的标签相互匹配:

from keras import backend as K

def full_multi_label_metric(y_true, y_pred):
    comp = K.equal(y_true, K.round(y_pred))
    return K.cast(K.all(comp, axis=-1), K.floatx())