是否可以编写一个基于Keras中批次中样本输出差异的自定义损失函数?

时间:2019-05-12 04:18:07

标签: tensorflow keras

我正在尝试在Keras中实现损失函数,该函数可以执行以下操作:

假设y0,y1,...,yn是批量输入x0,x1,...,xn的模型批量输出,这里batch_size为n + 1,则输出yi每个xi是一个标量值,我希望损失函数能够为此批次计算整体损失,如下所示:

K.log(K.sigmoid(y1-y0))+ K.log(K.sigmoid(y2-y1))+ ... + K.log(K.sigmoid(yn-yn-1))

我当时想使用Lambda层首先将批处理输出[y0,y1,...,yn]转换为[y1-y0,y2-y1,...,yn-yn-1],然后使用转换后的输出上的自定义损失函数。

但是,我不确定Keras是否可以理解Lambda层中没有权重要更新,也不清楚 Keras如何将梯度传播回Lambda层,因为Keras通常要求每个图层/损失函数对单个样本输入进行操作,但是我的图层将采用一批样本的整个输出。有人解决过类似的问题吗?谢谢!

1 个答案:

答案 0 :(得分:0)

下面的切片对您有用吗(尽管我没有使用keras)。

batch = 4
num_classes = 6
logits = tf.random.uniform(shape=[batch, num_classes])

logits1 = tf.slice(logits, (0, 0), [batch, num_classes-1])
logits2 = tf.slice(logits, (0, 1), [batch, num_classes-1])

delta = logits2 - logits1
loss = tf.reduce_sum(tf.log(tf.nn.sigmoid(delta)), axis=-1)

with tf.Session() as sess:
  logits, logits1, logits2, delta, loss  = sess.run([logits, logits1, logits2, 
                                                     delta, loss])

  print 'logits\n', logits
  print 'logits2\n', logits2
  print 'logits1\n', logits1
  print 'delta\n', delta
  print 'loss\n', loss

结果:

logits
[[ 0.61241663  0.70075285  0.98333454  0.4117974   0.5943476   0.84245574]
 [ 0.02499413  0.22279179  0.70742595  0.34853518  0.7837007   0.88074362]
 [ 0.35030317  0.36670768  0.64244425  0.87957716  0.22823489  0.45076978]
 [ 0.38116801  0.39040041  0.82510674  0.64789391  0.45415008  0.03520513]]
logits2
[[ 0.70075285  0.98333454  0.4117974   0.5943476   0.84245574]
 [ 0.22279179  0.70742595  0.34853518  0.7837007   0.88074362]
 [ 0.36670768  0.64244425  0.87957716  0.22823489  0.45076978]
 [ 0.39040041  0.82510674  0.64789391  0.45415008  0.03520513]]
logits1
[[ 0.61241663  0.70075285  0.98333454  0.4117974   0.5943476 ]
 [ 0.02499413  0.22279179  0.70742595  0.34853518  0.7837007 ]
 [ 0.35030317  0.36670768  0.64244425  0.87957716  0.22823489]
 [ 0.38116801  0.39040041  0.82510674  0.64789391  0.45415008]]
delta
[[ 0.08833623  0.28258169 -0.57153714  0.18255019  0.24810815]
 [ 0.19779766  0.48463416 -0.35889077  0.43516552  0.09704292]
 [ 0.01640451  0.27573657  0.23713291 -0.65134227  0.22253489]
 [ 0.0092324   0.43470633 -0.17721283 -0.19374382 -0.41894495]]
loss
[-3.41376281 -3.11249781 -3.49031925 -3.69255161]