如何在广义骰子损失计算(向后传递)中包括忽略标签以进行损失计算?

时间:2019-07-27 13:17:52

标签: tensorflow deep-learning conv-neural-network caffe

我已经编写了用于语义分割的广义Dice损失计算的代码,并且在不考虑ignore_label的情况下也可以很好地工作。现在,我尝试扩展到包括ignore_label,以在损失计算期间忽略特定标签。

我有一个形状为(1, 5, 93, 349, 219),标签为(1, 1, 93, 349, 219)的输入预测,我想ignore_label=1。在前向传递中,我创建一个mask=np.one_like(prediction),然后创建mask[:,ignore_label,:,:,:]=0。我先创建one_hot_label(形状:1, 5, 93, 349, 219),然后(在计算损失之前)

#I multiply mask separately to prediction and one_hot_label. 
gt_masked = np.multiply(one_hot_label,mask)
pred_masked = np.multiply(prediction,mask)

该特定ignore_label的正向传递损失将为零,并将被添加到其他类别的其他损失值中。

我的问题是我应该在倒数传递中将与ignore_label相关的张量通道放入什么值,以便计算梯度?

0 个答案:

没有答案