缺少标签的深层多任务学习

时间:2018-03-11 01:03:59

标签: python tensorflow neural-network deep-learning

我有一个深度多任务网络,旨在处理三个单独的分类任务。虽然每个训练示例都有两个分类任务的标签,但只有大约10%到15%的训练样本具有第三个任务的标签。网络具有多个共享层,随后是由一个或多个完全连接层和softmax / sigmoid输出层组成的每个任务的单独头部。

为了处理第三个任务中缺少的标签,我使用tf.boolean_mask来掩盖每个批次中没有标签的示例,除非在批处理没有训练示例的罕见情况下工作得很好标签;即整个批次中没有任务3的标签。在这种情况下,布尔掩码(正确)返回空张量,tf.softmax_cross_entropy_with_logits返回nan在训练期间引发错误。

我目前解决此问题的方法是检查批次是否没有第三项任务的标签,如果是,则在培训期间跳过批次。虽然这可以避免错误,但我想知道我是否可以编辑计算图来处理这种相对罕见的情况,因此我不必跳过批次。

以下是第三个任务的输出层和总损失函数的代码片段。此任务有完全连接的层,在此输出层之前有多个共享层。

    # softmax output layer for natural categories
    with tf.variable_scope('Natural_Category_Output'):
        W = tf.get_variable('W', shape = [natural_layer_size, no_natural_categories], 
                            initializer = tf.glorot_uniform_initializer())
        b = tf.get_variable('b', shape = [no_natural_categories],
                            initializer = tf.glorot_uniform_initializer())

        natural_logits = tf.add(tf.matmul(natural_output, W), b, name = 'logits')
        masked_logits = tf.boolean_mask(natural_logits, natural_mask, axis = 0, name = 'masked_logits')

        natural_probabilities = tf.nn.softmax(natural_logits, name = 'probabilities')
        natural_predictions = tf.argmax(natural_logits, axis = 1, name = 'predictions')
        masked_predictions = tf.boolean_mask(natural_predictions, natural_mask, axis = 0, name = 'masked_predictions')

    # loss for the natural categories
    with tf.variable_scope('Natural_Category_Loss_Function'):
        masked_natural_category = tf.boolean_mask(binarized_natural_category, natural_mask, axis = 0, name = 'masked_natural_categories')
        natural_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels = masked_natural_category, 
                                                                                 logits = masked_logits), name = 'cross_entropy_loss')

        if uncertainty_weighting:
            # intialize weight variables
            natural_weight = tf.get_variable('natural_weight', shape = [], initializer = tf.constant_initializer(1.0))

            # augment the the loss function for the task
            natural_loss = tf.add(tf.divide(natural_loss, tf.multiply(tf.constant(2.0), tf.square(natural_weight))),
                                  tf.log(tf.square(natural_weight)), name = 'weighted_loss')

    # total loss function 
    with tf.variable_scope('Total_Loss'):
        loss = fs_loss + expense_loss + natural_loss

有没有人有办法更改图表来处理没有标签的批次?

1 个答案:

答案 0 :(得分:1)

基本上,您做对了。 另一种方法是在计算损失之前使用“ tf.gather”。 假设样本中没有标签,标签为“ -1”。

valid_idxs = tf.where(your_label > -1)[:, 0]
valid_logits = tf.gather(your_logits, valid_idxs)
valid_labels = tf.gather(your_label, valid_idxs)
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=valid_labels, logits=valid_logits)