我想实现在Keras [http://ydwen.github.io/papers/WenECCV16.pdf]中解释的中心损失
我开始创建一个包含2个输出的网络,例如:
inputs = Input(shape=(100,100,3))
...
fc = Dense(100)(#previousLayer#)
softmax = Softmax(fc)
model = Model(input, output=[softmax, fc])
model.compile(optimizer='sgd',
loss=['categorical_crossentropy', 'center_loss'],
metrics=['accuracy'], loss_weights=[1., 0.2])
首先,这样做,这是继续进行的好方法吗?
其次,我不知道如何在keras中实现center_loss。 Center_loss看起来像均方误差,但不是将值与固定标签进行比较, 它将值与每次迭代时更新的数据进行比较。
感谢您的帮助
答案 0 :(得分:1)
对我来说,您可以按照以下步骤实施此图层:
写一个
的自定义图层ComputeCenter
需要两个输入:i)。 groudtruth标签y_true
(不是单热编码,只是整数)和ii)。预计会员资格y_pred
包含一个大小为W
数组的查找表num_classes x num_feats
作为可训练的权重(参见BatchNormalization Layer),W [j]是移动平均值的占位符第j课的特色。
计算论文中指定的中心损失。
D
要计算中心损失,您需要
W[j]
更新y_pred[k]
y_true[k]=j
c_true[k]=W[j]
的中心要素y_pred[k]
y_true[k]=j
和y_pred
之间的距离。c_true
和c_true[k] = W[j]
是示例索引,k
是y_pred [k]的基本事实标签。使用j
计算此损失。请注意,请勿在{{1}}中添加此丢失。
最后,如果需要,您可以在中心损失中添加一些损失系数。