我有一个深度多任务网络,旨在处理三个单独的分类任务。虽然每个训练示例都有两个分类任务的标签,但只有大约10%到15%的训练样本具有第三个任务的标签。网络具有多个共享层,随后是由一个或多个完全连接层和softmax / sigmoid输出层组成的每个任务的单独头部。
为了处理第三个任务中缺少的标签,我使用tf.boolean_mask来掩盖每个批次中没有标签的示例,除非在批处理没有训练示例的罕见情况下工作得很好标签;即整个批次中没有任务3的标签。在这种情况下,布尔掩码(正确)返回空张量,tf.softmax_cross_entropy_with_logits返回nan在训练期间引发错误。
我目前解决此问题的方法是检查批次是否没有第三项任务的标签,如果是,则在培训期间跳过批次。虽然这可以避免错误,但我想知道我是否可以编辑计算图来处理这种相对罕见的情况,因此我不必跳过批次。
以下是第三个任务的输出层和总损失函数的代码片段。此任务有完全连接的层,在此输出层之前有多个共享层。
# softmax output layer for natural categories
with tf.variable_scope('Natural_Category_Output'):
W = tf.get_variable('W', shape = [natural_layer_size, no_natural_categories],
initializer = tf.glorot_uniform_initializer())
b = tf.get_variable('b', shape = [no_natural_categories],
initializer = tf.glorot_uniform_initializer())
natural_logits = tf.add(tf.matmul(natural_output, W), b, name = 'logits')
masked_logits = tf.boolean_mask(natural_logits, natural_mask, axis = 0, name = 'masked_logits')
natural_probabilities = tf.nn.softmax(natural_logits, name = 'probabilities')
natural_predictions = tf.argmax(natural_logits, axis = 1, name = 'predictions')
masked_predictions = tf.boolean_mask(natural_predictions, natural_mask, axis = 0, name = 'masked_predictions')
# loss for the natural categories
with tf.variable_scope('Natural_Category_Loss_Function'):
masked_natural_category = tf.boolean_mask(binarized_natural_category, natural_mask, axis = 0, name = 'masked_natural_categories')
natural_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels = masked_natural_category,
logits = masked_logits), name = 'cross_entropy_loss')
if uncertainty_weighting:
# intialize weight variables
natural_weight = tf.get_variable('natural_weight', shape = [], initializer = tf.constant_initializer(1.0))
# augment the the loss function for the task
natural_loss = tf.add(tf.divide(natural_loss, tf.multiply(tf.constant(2.0), tf.square(natural_weight))),
tf.log(tf.square(natural_weight)), name = 'weighted_loss')
# total loss function
with tf.variable_scope('Total_Loss'):
loss = fs_loss + expense_loss + natural_loss
有没有人有办法更改图表来处理没有标签的批次?
答案 0 :(得分:1)
基本上,您做对了。 另一种方法是在计算损失之前使用“ tf.gather”。 假设样本中没有标签,标签为“ -1”。
valid_idxs = tf.where(your_label > -1)[:, 0]
valid_logits = tf.gather(your_logits, valid_idxs)
valid_labels = tf.gather(your_label, valid_idxs)
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=valid_labels, logits=valid_logits)