Tensorflow:在丢失操作中按批次的某些条件标签过滤

时间:2017-07-31 16:49:40

标签: python tensorflow

我的网络中有两个标签批次。第一批是嘈杂的标签,第二批是经过验证的标签。并非所有噪声标签都有经过验证的标签,但批次具有相同的大小。必须仅在经过验证的标签上计算损失。有没有办法过滤批次以便只在损失中使用经过验证的标签?

这是我的损失定义:

def loss_clean_v1(label_output, label_verified):
    loss_clean_value = tf.reduce_sum(tf.abs(tf.subtract(label_output, label_verified)))
    # debug
    loss_clean_value = tf.Print(loss_clean_value, [loss_clean_value], message="Loss label cleaning: ")
    return loss_clean_value

1 个答案:

答案 0 :(得分:0)

我找到了解决问题的方法!

label是从清洁到一批嘈杂标签获得的一批输出标签。 label_verified是一批经过验证的标签;并非每个嘈杂的标签都有经过验证的标签,因此我使用全-1 -1 np.full((1, 6012), -1)的假标签。

在我的遗失中,必须仅使用经过验证的标签。我需要过滤label_verified批次。

为此,我使用带有布尔向量的tf.where(),其大小是批处理中的标签数量,条件为[False, True]

问题是Tensorflow只提供了元素方式的比较运算符,它们不会产生我需要的布尔向量。

为了获得该向量,我改进了元素条件结果:condition_refined = tf.cast(tf.reduce_sum(tf.cast(condition, tf.int32), 1), tf.bool)

过滤的批量结果为label_filtered = tf.where(condition_refined, label_verified, label)

如果条件为False,我会使用标签输出,因为我的损失包含已验证和输出标签的减法。因此输出标签 - 输出标签为损失贡献0。

这是我的代码:

def loss_clean(label, label_verified):

    # in np.full, 2 is my batch size 
    condition = tf.not_equal(label_verified, np.full((2, 6012), -1))

    # debug
    condition = tf.Print(condition, [condition], message="Condition: ")

    condition_refined = tf.cast(tf.reduce_sum(tf.cast(condition, tf.int32), 1), tf.bool)

    # debug
    condition_refined = tf.Print(condition_refined, [condition_refined], message="Condition_refined: ")

    label_filtered = tf.where(condition_refined, label_verified, label)

    # debug
    label_filtered = tf.Print(label_filtered, [label_filtered], message="Label_filtered: ")

    loss_clean_value = tf.reduce_sum(tf.abs(tf.subtract(label, label_filtered)))

    # debug
    loss_clean_value = tf.Print(loss_clean_value, [loss_clean_value], message="Loss label cleaning: ")

    return loss_clean_value