Keras的中心损失

时间:2016-10-21 10:10:44

标签: keras

我想实现在Keras [http://ydwen.github.io/papers/WenECCV16.pdf]中解释的中心损失

我开始创建一个包含2个输出的网络,例如:

inputs = Input(shape=(100,100,3))
...
fc = Dense(100)(#previousLayer#)
softmax = Softmax(fc)
model = Model(input, output=[softmax, fc])
model.compile(optimizer='sgd', 
              loss=['categorical_crossentropy', 'center_loss'],
              metrics=['accuracy'], loss_weights=[1., 0.2])

首先,这样做,这是继续进行的好方法吗?

其次,我不知道如何在keras中实现center_loss。 Center_loss看起来像均方误差,但不是将值与固定标签进行比较, 它将值与每次迭代时更新的数据进行比较。

感谢您的帮助

1 个答案:

答案 0 :(得分:1)

对我来说,您可以按照以下步骤实施此图层:

  1. 写一个

    的自定义图层ComputeCenter
    • 需要两个输入:i)。 groudtruth标签y_true(不是单热编码,只是整数)和ii)。预计会员资格y_pred

    • 包含一个大小为W数组的查找表num_classes x num_feats作为可训练的权重(参见BatchNormalization Layer),W [j]是移动平均值的占位符第j课的特色。

    • 计算论文中指定的中心损失。

    • 输出生成的距离数组D
  2. 要计算中心损失,您需要

    • i)中。根据{{​​1}},
    • 使用W[j]更新y_pred[k]
    • ⅱ)。检索y_true[k]=j
    • 的样本c_true[k]=W[j]的中心要素y_pred[k]
    • iii)计算y_true[k]=jy_pred之间的距离。
    • 此处c_truec_true[k] = W[j]是示例索引,k是y_pred [k]的基本事实标签。
  3. 使用j计算此损失。请注意,请勿在{{1​​}}中添加此丢失。

  4. 最后,如果需要,您可以在中心损失中添加一些损失系数。