优化忽略某些标签值的损失函数

时间:2019-08-28 13:38:01

标签: python tensorflow optimization loss-function

我正在为输入稀疏的数据编写一个二进制分类器,我想将输入0表示不存在数据,而不是该值肯定为0。我最初使用的是{{1 }},但它对假阳性的惩罚过于严厉。

我成功编写了一个损失函数,如下所示,该函数提供了我想要的行为,但是它慢了多个数量级,我需要找到一种方法来窃取一些性能。

tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(...))

1 个答案:

答案 0 :(得分:0)

我相信我已经提出了一个解决方案。在我看来,我实际上不需要针对标签的数量进行标准化以获得平均值而不是总和,因此通过运行比较并将结果强制转换为条件,创建了{0,1}掩码浮动。然后,我使用矩阵乘法使用此掩码获取tf.nn.sigmoid_cross_entropy_with_logits(...)结果的点积,以便将条件为true的值添加到总和中,并将条件为false的值乘以零来取消它们。

这可能不是最好的解决方案,我将暂时搁置一个问题,以防有人可以就惯用或性能方面的好处提出更好的选择,但它可以满足我当前的需求

def loss(labels, logits):
    labels = tf.reshape(labels, shape=(-1,))
    logits = tf.reshape(logits, shape=(-1,))
    output = tf.reshape(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=labels,
            logits=logits),
        shape=(1, -1))

    mask = tf.reshape(
        tf.cast(labels > logits, tf.float32),  # True iff label==1
        shape=(-1, 1))

    return tf.matmul(output, mask)