我正在使用sigmoid
和binary_crossentropy
进行多标签分类。
例如,y_true
的标签就像[1,0,1,0,0]
,而y_pred
的标签就像[0.8,0.3,0.9,0,0]
。
如何设置Keras自定义指标函数,以便将y_pred
中大于0.5的每个元素映射为1,将y_pred
中小于0.5的每个元素映射为0,然后比较数字y_pred
中与y_true
匹配的标签的数量?
答案 0 :(得分:2)
由于您正在执行多标签分类,因此您似乎希望将整个真实标签和预测标签相互比较。例如,对于单个标签的真实标签为[1, 0, 0]
并且预测标签为[0, 0, 0]
的样本,您将预测准确性视为零(尽管已经预测了第二和第三类的标签正确)。在这种情况下,您可以比较标签,然后从后端使用all()
方法来确保所有类的标签相互匹配:
from keras import backend as K
def full_multi_label_metric(y_true, y_pred):
comp = K.equal(y_true, K.round(y_pred))
return K.cast(K.all(comp, axis=-1), K.floatx())