我试图基于2个模型的中间层的输出来计算损耗,然后我应用“平均”层来获取输出。我不太确定该怎么做。
损失函数在this paper中定义,我已经实现了以下损失函数,导致恒定损失。
关于如何解决此问题的任何想法?谢谢。
def sim(y_true, y_pred):
tao = 1
z1 = y_pred[:16]
z2 = y_pred[16:]
z1_norm = K.sqrt(K.sum(z1 ** 2))
z2_norm = K.sqrt(K.sum(z2 ** 2))
sij = K.exp(K.dot(z1, K.transpose(z2)) / (tao * z1_norm * z2_norm))
sji = K.exp(K.dot(z2, K.transpose(z1)) / (tao * z1_norm * z2_norm))
lij = -K.log((sij) / (K.sum(sij, axis=0) + K.epsilon()))
lji = -K.log((sji) / (K.sum(sji, axis=0) + K.epsilon()))
return K.mean(lij + lji)
16/5216 [..............................] - ETA: 3:07:28 - loss: 5.5452
32/5216 [..............................] - ETA: 1:34:14 - loss: 5.5452
48/5216 [..............................] - ETA: 1:03:08 - loss: 5.5452
64/5216 [..............................] - ETA: 47:35 - loss: 5.5452
...
5184/5216 [============================>.] - ETA: 0s - loss: 5.5452
5200/5216 [============================>.] - ETA: 0s - loss: 5.5452
5216/5216 [==============================] - 131s 25ms/step - loss: 5.5452 - val_loss: 5.5452
我的模特是
base1 = ResNet50(weights='imagenet', include_top=False, input_shape = (224,224,3))
base2 = ResNet50(weights='imagenet', include_top=False, input_shape = (224,224,3))
for i, layer in enumerate(base1.layers):
layer.name += '_model1'
for i, layer in enumerate(base2.layers):
layer.name += '_model2'
f1 = base1.output
f2 = base2.output
f1 = Dropout(0.5)(f1)
f2 = Dropout(0.5)(f2)
h1 = GlobalAveragePooling2D()(f1)
h2 = GlobalAveragePooling2D()(f2)
z1 = Dense(256, activation='relu', name='z1')(h1)
z2 = Dense(256, activation='relu', name='z2')(h2)
pred = Concatenate(axis=0)([z1, z2])
model = Model(inputs = [base1.input, base2.input], output = pred)
model.compile(optimizer=Adam(3e-4), loss=sim)
model.fit([x1, x2], y, epochs=1, batch_size=32)
其他信息-我已经打印了张量z1,z2的值,它们似乎对于每个批次都是不同的。
576/5216 [==>...........................] - ETA: 7:50 - loss: 5.5452
z1 is[[0 1.89108324 1.09692252 0 0 0 0 0 0.551534235 0...]...]
z2 is[[0.583885849 0 0 0 0 2.4857161 0.370762467 0 0 0...]...]
592/5216 [==>...........................] - ETA: 7:38 - loss: 5.5452
z2 is[[0.981608689 0 0 0 0.487944722 4.01521063 1.57426798 0.254898 0 0...]...]
z1 is[[0 1.78056419 2.08047867 0 0.0729936138 0 0 0 0 0.225044087...]...]
608/5216 [==>...........................] - ETA: 7:27 - loss: 5.5452
z2 is[[0.917239964 0 0 0 0 2.9613061 0.337595135 0.418117255 0 0...]...]
z1 is[[0 1.5660435 0.462322 0.0269343853 0.481997907 0 0 0 0.871785402 0.672020376...]...]