I am trying to implement a loss function in Keras that does the following:
Suppose y0, y1, ..., yn are the model's outputs for a batch of inputs x0, x1, ..., xn (so batch_size is n + 1), and each output yi for input xi is a scalar. I want the loss function to compute the overall loss for this batch as:
K.log(K.sigmoid(y1 - y0)) + K.log(K.sigmoid(y2 - y1)) + ... + K.log(K.sigmoid(yn - y(n-1)))
My idea was to use a Lambda layer to first transform the batch output [y0, y1, ..., yn] into [y1 - y0, y2 - y1, ..., yn - y(n-1)], and then apply a custom loss function to the transformed output, as sketched below.
However, I am not sure whether Keras understands that the Lambda layer has no weights to update, nor how Keras would propagate gradients back through it, since Keras usually expects each layer/loss function to operate on individual samples, whereas my layer takes the entire output of a batch of samples. Has anyone solved a similar problem? Thanks!
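For concreteness, here is a rough, untested sketch of the idea (the input feature size, the Dense layer, and the optimizer are placeholders I made up for illustration):

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, Dense, Lambda
from tensorflow.keras.models import Model

inp = Input(shape=(10,))            # hypothetical feature size
score = Dense(1)(inp)               # scalar output yi per sample xi
# Lambda layer with no trainable weights: consecutive differences
# across the batch dimension, i.e. [y1-y0, y2-y1, ..., yn-y(n-1)].
diffs = Lambda(lambda y: y[1:] - y[:-1])(score)
model = Model(inp, diffs)

def batch_loss(y_true, y_pred):
    # y_pred is already the vector of differences; y_true is ignored.
    # Returns a single overall loss for the whole batch.
    return K.sum(K.log(K.sigmoid(y_pred)))

model.compile(optimizer='sgd', loss=batch_loss)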
Answer 0 (score: 0)
Will slicing like the following work for you (though I am not using keras)?
import tensorflow as tf  # TF 1.x style (graph mode + Session)

batch = 4
num_classes = 6
logits = tf.random.uniform(shape=[batch, num_classes])
# Two shifted views of the same tensor: drop the last column
# vs. drop the first column, so their difference gives the
# consecutive deltas in a single op.
logits1 = tf.slice(logits, (0, 0), [batch, num_classes - 1])
logits2 = tf.slice(logits, (0, 1), [batch, num_classes - 1])
delta = logits2 - logits1
# Per-row sum of log-sigmoids of the consecutive differences.
loss = tf.reduce_sum(tf.log(tf.nn.sigmoid(delta)), axis=-1)

with tf.Session() as sess:
    logits, logits1, logits2, delta, loss = sess.run(
        [logits, logits1, logits2, delta, loss])
    print('logits\n', logits)
    print('logits2\n', logits2)
    print('logits1\n', logits1)
    print('delta\n', delta)
    print('loss\n', loss)
Result:
logits
[[ 0.61241663 0.70075285 0.98333454 0.4117974 0.5943476 0.84245574]
[ 0.02499413 0.22279179 0.70742595 0.34853518 0.7837007 0.88074362]
[ 0.35030317 0.36670768 0.64244425 0.87957716 0.22823489 0.45076978]
[ 0.38116801 0.39040041 0.82510674 0.64789391 0.45415008 0.03520513]]
logits2
[[ 0.70075285 0.98333454 0.4117974 0.5943476 0.84245574]
[ 0.22279179 0.70742595 0.34853518 0.7837007 0.88074362]
[ 0.36670768 0.64244425 0.87957716 0.22823489 0.45076978]
[ 0.39040041 0.82510674 0.64789391 0.45415008 0.03520513]]
logits1
[[ 0.61241663 0.70075285 0.98333454 0.4117974 0.5943476 ]
[ 0.02499413 0.22279179 0.70742595 0.34853518 0.7837007 ]
[ 0.35030317 0.36670768 0.64244425 0.87957716 0.22823489]
[ 0.38116801 0.39040041 0.82510674 0.64789391 0.45415008]]
delta
[[ 0.08833623 0.28258169 -0.57153714 0.18255019 0.24810815]
[ 0.19779766 0.48463416 -0.35889077 0.43516552 0.09704292]
[ 0.01640451 0.27573657 0.23713291 -0.65134227 0.22253489]
[ 0.0092324 0.43470633 -0.17721283 -0.19374382 -0.41894495]]
loss
[-3.41376281 -3.11249781 -3.49031925 -3.69255161]
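For completeness, the same shifted-slice trick ported to a Keras custom loss over the batch dimension might look like this (a sketch only, since as noted I am not using keras; I negate the sum so that minimizing the loss maximizes your sum of log-sigmoids):

import tensorflow as tf
from tensorflow.keras import backend as K

def sequential_loss(y_true, y_pred):
    # y_pred: (batch, 1) scalar outputs [y0, ..., yn]; y_true is unused.
    delta = y_pred[1:] - y_pred[:-1]   # [y1-y0, y2-y1, ..., yn-y(n-1)]
    # Negated so that minimizing the loss maximizes the sum of
    # log-sigmoids from the question.
    return -K.sum(K.log(K.sigmoid(delta)))

# Quick numeric check on a toy batch of scalar outputs:
y = K.constant([[0.1], [0.5], [0.3], [0.9]])
print(K.eval(sequential_loss(y, y)))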