我正在尝试在张量流中实现多类骰子损失功能。由于它是多类骰子,因此我需要将每个类的概率转换为一种热门形式。例如,如果我的网络输出以下概率:
[0.2、0.6、0.1、0.1](假设有4个班级)
我需要将其转换为:
[0 1 0 0]
这可以通过使用 tf.argmax 后跟 tf.one_hot
def generalized_dice_loss(labels, logits):
#labels shape [batch_size,128,128,64,1] dtype=float32
#logits shape [batch_size,128,128,64,7] dtype=float32
labels=tf.cast(labels,tf.int32)
smooth = tf.constant(1e-17)
shape = tf.TensorShape(logits.shape).as_list()
depth = int(shape[-1])
labels = tf.one_hot(labels, depth, dtype=tf.int32,axis=4)
labels = tf.squeeze(labels, axis=5)
logits = tf.argmax(logits,axis=4)
logits = tf.one_hot(logits, depth, dtype=tf.int32,axis=4)
numerator = tf.reduce_sum(labels * logits, axis=[1, 2, 3])
denominator = tf.reduce_sum(labels + logits, axis=[1, 2, 3])
numerator=tf.cast(numerator,tf.float32)
denominator=tf.cast(denominator,tf.float32)
loss = tf.reduce_mean(1.0 - 2.0*(numerator + smooth)/(denominator + smooth))
return loss
问题是,tf.argmax是不可区分的,它将引发错误:
ValueError: An operation has `None` for gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.
如何解决这个问题?我们可以不使用tf.argmax做同样的事情吗?
答案 0 :(得分:0)
看看How is the smooth dice loss differentiable?。您无需进行转换(将[0.2, 0.6, 0.1, 0.1]
转换为[0 1 0 0]
)。只需将它们保留为0到1之间的连续值即可。
如果我正确理解,则损失函数只是实现您预期目标的替代方法。即使它不相同,只要它是一个很好的代理,就可以了(否则,是不可区分的)。
在评估期间,可以随时使用tf.argmax
来获取真实指标。