keras自定义损失分类错误数

时间:2018-12-19 17:10:39

标签: python tensorflow keras

我在tensorflow后端使用keras,并尝试编写一个自定义损失函数,该函数仅计算错误的分类预测数。这是我的尝试:

def error_count_loss(yTrue, yPred):
    """Sum and return the number of incorrect predictions.

    Parameters
    ----------
    yTrue : One-hot encoded truth
    yPred : Softmax encoded prediction
    """
    yTrue_argmax = K.argmax(yTrue, axis=1)
    yPred_argmax = K.argmax(yPred, axis=1)
    incorrect_bool = K.not_equal(yTrue_argmax, yPred_argmax)
    incorrect_float = K.cast(incorrect_bool, 'float32')
    return K.sum(incorrect_float)

此代码失败,因为argmax是不可区分的。有没有其他方法可以计算不正确的预测?

1 个答案:

答案 0 :(得分:1)

您可以尝试使用其他具有梯度的函数来编写接近于此的东西,例如:

def error_count_loss(yTrue, yPred):
    return K.sum(K.abs(K.sign(yTrue) - K.sign(yFalse)))

但这不是训练的最佳损失函数。尝试查看categorical_crossentropy