多类分割的广义骰子损失:keras实现

时间:2018-02-27 15:14:12

标签: machine-learning deep-learning keras loss-function semantic-segmentation

我刚刚在keras中实现了广义骰子丢失(骰子丢失的多级版本),如ref中所述:

(我的目标定义为:(batch_size,image_dim1,image_dim2,image_dim3,nb_of_classes))

def generalized_dice_loss_w(y_true, y_pred): 
    # Compute weights: "the contribution of each label is corrected by the inverse of its volume"
    Ncl = y_pred.shape[-1]
    w = np.zeros((Ncl,))
    for l in range(0,Ncl): w[l] = np.sum( np.asarray(y_true[:,:,:,:,l]==1,np.int8) )
    w = 1/(w**2+0.00001)

    # Compute gen dice coef:
    numerator = y_true*y_pred
    numerator = w*K.sum(numerator,(0,1,2,3))
    numerator = K.sum(numerator)

    denominator = y_true+y_pred
    denominator = w*K.sum(denominator,(0,1,2,3))
    denominator = K.sum(denominator)

    gen_dice_coef = numerator/denominator

    return 1-2*gen_dice_coef

但事情肯定是错的。我正在使用3D图像,我必须分为4个类(1个背景类和3个对象类,我有一个不平衡的数据集)。首先是奇怪的事情:虽然我的训练失误和准确度在训练期间有所改善(并且收敛速度非常快),但我的验证损失/准确度在时期内是恒定的(见image)。其次,在预测测试数据时,只预测背景类:我得到一个恒定的音量。

我使用了完全相同的数据和脚本,但是使用了分类交叉熵损失并获得了合理的结果(对象类被分段)。这意味着我的实施出了问题。知道它可能是什么吗?

另外我相信,对于keras社区来说,有一个广义的骰子丢失实现是有用的,因为它似乎在大多数最近的语义分割任务中使用(至少在医学图像社区中)。

PS:对我来说,如何定义权重似乎很奇怪;我得到大约10 ^ -10的值。还有其他人试图实现这个吗?我也在没有重量的情况下测试了我的功能但是遇到了同样的问题。

1 个答案:

答案 0 :(得分:1)

我认为这是您的体重问题。假设您正在尝试解决多类分割问题,但是在每个图像中,只有少数几个类存在。一个玩具的例子(以及导致我这个问题的例子)是通过以下方式从mnist创建细分数据集。

x = 28x28图像,y = 28x28x11,其中如果每个像素均低于归一化灰度值0.4,则将其分类为背景,否则将其分类为x的原始类别。因此,如果您看到一张数字为1的图片,则将得到一堆分类为1的像素以及背景。

现在,在此数据集中,图像中将只存在两个类。这意味着,在您损失骰子之后,将有9个重量 1./(0. + eps) = large 因此,对于每张图片,我们都会严厉惩罚所有9个不存在的类。在这种情况下,网络希望找到的一个明显强的局部最小值是将一切作为背景知识进行预测。

我们确实希望对不在图像中但没有那么强烈的任何错误预测的类进行惩罚。因此,我们只需要修改权重。这是我的方法:

def gen_dice(y_true, y_pred, eps=1e-6):
    """both tensors are [b, h, w, classes] and y_pred is in logit form"""

    # [b, h, w, classes]
    pred_tensor = tf.nn.softmax(y_pred)
    y_true_shape = tf.shape(y_true)

    # [b, h*w, classes]
    y_true = tf.reshape(y_true, [-1, y_true_shape[1]*y_true_shape[2], y_true_shape[3]])
    y_pred = tf.reshape(pred_tensor, [-1, y_true_shape[1]*y_true_shape[2], y_true_shape[3]])

    # [b, classes]
    # count how many of each class are present in 
    # each image, if there are zero, then assign
    # them a fixed weight of eps
    counts = tf.reduce_sum(y_true, axis=1)
    weights = 1. / (counts ** 2)
    weights = tf.where(tf.math.is_finite(weights), weights, eps)

    multed = tf.reduce_sum(y_true * y_pred, axis=1)
    summed = tf.reduce_sum(y_true + y_pred, axis=1)

    # [b]
    numerators = tf.reduce_sum(weights*multed, axis=-1)
    denom = tf.reduce_sum(weights*summed, axis=-1)
    dices = 1. - 2. * numerators / denom
    dices = tf.where(tf.math.is_finite(dices), dices, tf.zeros_like(dices))
    return tf.reduce_mean(dices)