与keras的facenet三胞胎损失

时间:2016-12-10 13:18:04

标签: neural-network tensorflow keras

我正在尝试使用Thensorflow后端在Keras中实现facenet,我对三元组丢失有一些问题。enter image description here

我用3 * n个图像调用fit函数,然后按如下方式定义自定义丢失函数:

def triplet_loss(self, y_true, y_pred):

    embeddings = K.reshape(y_pred, (-1, 3, output_dim))

    positive_distance = K.mean(K.square(embeddings[:,0] - embeddings[:,1]),axis=-1)
    negative_distance = K.mean(K.square(embeddings[:,0] - embeddings[:,2]),axis=-1)
    return K.mean(K.maximum(0.0, positive_distance - negative_distance + _alpha))

self._model.compile(loss=triplet_loss, optimizer="sgd")
self._model.fit(x=x,y=y,nb_epoch=1, batch_size=len(x))

其中y只是一个填充0s的虚拟数组

问题在于即使在批量大小为20的第一次迭代之后,模型也开始为所有图像预测相同的嵌入。因此,当我第一次对批处理进行预测时,每次嵌入都是不同的。然后我做了拟合并再次预测,突然所有嵌入对于批处理中的所有图像变得几乎相同

另请注意,模型末尾有一个Lambda图层。它规范了网络的输出,因此所有嵌入都具有单面长度,如面网研究中所建议的那样。

有人可以帮我吗?

模型摘要

    Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input_1 (InputLayer)             (None, 224, 224, 3)   0                                            
____________________________________________________________________________________________________
convolution2d_1 (Convolution2D)  (None, 112, 112, 64)  9472        input_1[0][0]                    
____________________________________________________________________________________________________
batchnormalization_1 (BatchNormal(None, 112, 112, 64)  128         convolution2d_1[0][0]            
____________________________________________________________________________________________________
maxpooling2d_1 (MaxPooling2D)    (None, 56, 56, 64)    0           batchnormalization_1[0][0]       
____________________________________________________________________________________________________
convolution2d_2 (Convolution2D)  (None, 56, 56, 64)    4160        maxpooling2d_1[0][0]             
____________________________________________________________________________________________________
batchnormalization_2 (BatchNormal(None, 56, 56, 64)    128         convolution2d_2[0][0]            
____________________________________________________________________________________________________
convolution2d_3 (Convolution2D)  (None, 56, 56, 192)   110784      batchnormalization_2[0][0]       
____________________________________________________________________________________________________
batchnormalization_3 (BatchNormal(None, 56, 56, 192)   384         convolution2d_3[0][0]            
____________________________________________________________________________________________________
maxpooling2d_2 (MaxPooling2D)    (None, 28, 28, 192)   0           batchnormalization_3[0][0]       
____________________________________________________________________________________________________
convolution2d_5 (Convolution2D)  (None, 28, 28, 96)    18528       maxpooling2d_2[0][0]             
____________________________________________________________________________________________________
convolution2d_7 (Convolution2D)  (None, 28, 28, 16)    3088        maxpooling2d_2[0][0]             
____________________________________________________________________________________________________
maxpooling2d_3 (MaxPooling2D)    (None, 28, 28, 192)   0           maxpooling2d_2[0][0]             
____________________________________________________________________________________________________
convolution2d_4 (Convolution2D)  (None, 28, 28, 64)    12352       maxpooling2d_2[0][0]             
____________________________________________________________________________________________________
convolution2d_6 (Convolution2D)  (None, 28, 28, 128)   110720      convolution2d_5[0][0]            
____________________________________________________________________________________________________
convolution2d_8 (Convolution2D)  (None, 28, 28, 32)    12832       convolution2d_7[0][0]            
____________________________________________________________________________________________________
convolution2d_9 (Convolution2D)  (None, 28, 28, 32)    6176        maxpooling2d_3[0][0]             
____________________________________________________________________________________________________
merge_1 (Merge)                  (None, 28, 28, 256)   0           convolution2d_4[0][0]            
                                                                   convolution2d_6[0][0]            
                                                                   convolution2d_8[0][0]            
                                                                   convolution2d_9[0][0]            
____________________________________________________________________________________________________
convolution2d_11 (Convolution2D) (None, 28, 28, 96)    24672       merge_1[0][0]                    
____________________________________________________________________________________________________
convolution2d_13 (Convolution2D) (None, 28, 28, 32)    8224        merge_1[0][0]                    
____________________________________________________________________________________________________
maxpooling2d_4 (MaxPooling2D)    (None, 28, 28, 256)   0           merge_1[0][0]                    
____________________________________________________________________________________________________
convolution2d_10 (Convolution2D) (None, 28, 28, 64)    16448       merge_1[0][0]                    
____________________________________________________________________________________________________
convolution2d_12 (Convolution2D) (None, 28, 28, 128)   110720      convolution2d_11[0][0]           
____________________________________________________________________________________________________
convolution2d_14 (Convolution2D) (None, 28, 28, 64)    51264       convolution2d_13[0][0]           
____________________________________________________________________________________________________
convolution2d_15 (Convolution2D) (None, 28, 28, 64)    16448       maxpooling2d_4[0][0]             
____________________________________________________________________________________________________
merge_2 (Merge)                  (None, 28, 28, 320)   0           convolution2d_10[0][0]           
                                                                   convolution2d_12[0][0]           
                                                                   convolution2d_14[0][0]           
                                                                   convolution2d_15[0][0]           
____________________________________________________________________________________________________
convolution2d_16 (Convolution2D) (None, 28, 28, 128)   41088       merge_2[0][0]                    
____________________________________________________________________________________________________
convolution2d_18 (Convolution2D) (None, 28, 28, 32)    10272       merge_2[0][0]                    
____________________________________________________________________________________________________
convolution2d_17 (Convolution2D) (None, 14, 14, 256)   295168      convolution2d_16[0][0]           
____________________________________________________________________________________________________
convolution2d_19 (Convolution2D) (None, 14, 14, 64)    51264       convolution2d_18[0][0]           
____________________________________________________________________________________________________
maxpooling2d_5 (MaxPooling2D)    (None, 14, 14, 320)   0           merge_2[0][0]                    
____________________________________________________________________________________________________
merge_3 (Merge)                  (None, 14, 14, 640)   0           convolution2d_17[0][0]           
                                                                   convolution2d_19[0][0]           
                                                                   maxpooling2d_5[0][0]             
____________________________________________________________________________________________________
convolution2d_21 (Convolution2D) (None, 14, 14, 96)    61536       merge_3[0][0]                    
____________________________________________________________________________________________________
convolution2d_23 (Convolution2D) (None, 14, 14, 32)    20512       merge_3[0][0]                    
____________________________________________________________________________________________________
maxpooling2d_6 (MaxPooling2D)    (None, 14, 14, 640)   0           merge_3[0][0]                    
____________________________________________________________________________________________________
convolution2d_20 (Convolution2D) (None, 14, 14, 256)   164096      merge_3[0][0]                    
____________________________________________________________________________________________________
convolution2d_22 (Convolution2D) (None, 14, 14, 192)   166080      convolution2d_21[0][0]           
____________________________________________________________________________________________________
convolution2d_24 (Convolution2D) (None, 14, 14, 64)    51264       convolution2d_23[0][0]           
____________________________________________________________________________________________________
convolution2d_25 (Convolution2D) (None, 14, 14, 128)   82048       maxpooling2d_6[0][0]             
____________________________________________________________________________________________________
merge_4 (Merge)                  (None, 14, 14, 640)   0           convolution2d_20[0][0]           
                                                                   convolution2d_22[0][0]           
                                                                   convolution2d_24[0][0]           
                                                                   convolution2d_25[0][0]           
____________________________________________________________________________________________________
convolution2d_27 (Convolution2D) (None, 14, 14, 112)   71792       merge_4[0][0]                    
____________________________________________________________________________________________________
convolution2d_29 (Convolution2D) (None, 14, 14, 32)    20512       merge_4[0][0]                    
____________________________________________________________________________________________________
maxpooling2d_7 (MaxPooling2D)    (None, 14, 14, 640)   0           merge_4[0][0]                    
____________________________________________________________________________________________________
convolution2d_26 (Convolution2D) (None, 14, 14, 224)   143584      merge_4[0][0]                    
____________________________________________________________________________________________________
convolution2d_28 (Convolution2D) (None, 14, 14, 224)   226016      convolution2d_27[0][0]           
____________________________________________________________________________________________________
convolution2d_30 (Convolution2D) (None, 14, 14, 64)    51264       convolution2d_29[0][0]           
____________________________________________________________________________________________________
convolution2d_31 (Convolution2D) (None, 14, 14, 128)   82048       maxpooling2d_7[0][0]             
____________________________________________________________________________________________________
merge_5 (Merge)                  (None, 14, 14, 640)   0           convolution2d_26[0][0]           
                                                                   convolution2d_28[0][0]           
                                                                   convolution2d_30[0][0]           
                                                                   convolution2d_31[0][0]           
____________________________________________________________________________________________________
convolution2d_33 (Convolution2D) (None, 14, 14, 128)   82048       merge_5[0][0]                    
____________________________________________________________________________________________________
convolution2d_35 (Convolution2D) (None, 14, 14, 32)    20512       merge_5[0][0]                    
____________________________________________________________________________________________________
maxpooling2d_8 (MaxPooling2D)    (None, 14, 14, 640)   0           merge_5[0][0]                    
____________________________________________________________________________________________________
convolution2d_32 (Convolution2D) (None, 14, 14, 192)   123072      merge_5[0][0]                    
____________________________________________________________________________________________________
convolution2d_34 (Convolution2D) (None, 14, 14, 256)   295168      convolution2d_33[0][0]           
____________________________________________________________________________________________________
convolution2d_36 (Convolution2D) (None, 14, 14, 64)    51264       convolution2d_35[0][0]           
____________________________________________________________________________________________________
convolution2d_37 (Convolution2D) (None, 14, 14, 128)   82048       maxpooling2d_8[0][0]             
____________________________________________________________________________________________________
merge_6 (Merge)                  (None, 14, 14, 640)   0           convolution2d_32[0][0]           
                                                                   convolution2d_34[0][0]           
                                                                   convolution2d_36[0][0]           
                                                                   convolution2d_37[0][0]           
____________________________________________________________________________________________________
convolution2d_39 (Convolution2D) (None, 14, 14, 144)   92304       merge_6[0][0]                    
____________________________________________________________________________________________________
convolution2d_41 (Convolution2D) (None, 14, 14, 32)    20512       merge_6[0][0]                    
____________________________________________________________________________________________________
maxpooling2d_9 (MaxPooling2D)    (None, 14, 14, 640)   0           merge_6[0][0]                    
____________________________________________________________________________________________________
convolution2d_38 (Convolution2D) (None, 14, 14, 160)   102560      merge_6[0][0]                    
____________________________________________________________________________________________________
convolution2d_40 (Convolution2D) (None, 14, 14, 288)   373536      convolution2d_39[0][0]           
____________________________________________________________________________________________________
convolution2d_42 (Convolution2D) (None, 14, 14, 64)    51264       convolution2d_41[0][0]           
____________________________________________________________________________________________________
convolution2d_43 (Convolution2D) (None, 14, 14, 128)   82048       maxpooling2d_9[0][0]             
____________________________________________________________________________________________________
merge_7 (Merge)                  (None, 14, 14, 640)   0           convolution2d_38[0][0]           
                                                                   convolution2d_40[0][0]           
                                                                   convolution2d_42[0][0]           
                                                                   convolution2d_43[0][0]           
____________________________________________________________________________________________________
convolution2d_44 (Convolution2D) (None, 14, 14, 160)   102560      merge_7[0][0]                    
____________________________________________________________________________________________________
convolution2d_46 (Convolution2D) (None, 14, 14, 64)    41024       merge_7[0][0]                    
____________________________________________________________________________________________________
convolution2d_45 (Convolution2D) (None, 7, 7, 256)     368896      convolution2d_44[0][0]           
____________________________________________________________________________________________________
convolution2d_47 (Convolution2D) (None, 7, 7, 128)     204928      convolution2d_46[0][0]           
____________________________________________________________________________________________________
maxpooling2d_10 (MaxPooling2D)   (None, 7, 7, 640)     0           merge_7[0][0]                    
____________________________________________________________________________________________________
merge_8 (Merge)                  (None, 7, 7, 1024)    0           convolution2d_45[0][0]           
                                                                   convolution2d_47[0][0]           
                                                                   maxpooling2d_10[0][0]            
____________________________________________________________________________________________________
convolution2d_49 (Convolution2D) (None, 7, 7, 192)     196800      merge_8[0][0]                    
____________________________________________________________________________________________________
convolution2d_51 (Convolution2D) (None, 7, 7, 48)      49200       merge_8[0][0]                    
____________________________________________________________________________________________________
maxpooling2d_11 (MaxPooling2D)   (None, 7, 7, 1024)    0           merge_8[0][0]                    
____________________________________________________________________________________________________
convolution2d_48 (Convolution2D) (None, 7, 7, 384)     393600      merge_8[0][0]                    
____________________________________________________________________________________________________
convolution2d_50 (Convolution2D) (None, 7, 7, 384)     663936      convolution2d_49[0][0]           
____________________________________________________________________________________________________
convolution2d_52 (Convolution2D) (None, 7, 7, 128)     153728      convolution2d_51[0][0]           
____________________________________________________________________________________________________
convolution2d_53 (Convolution2D) (None, 7, 7, 128)     131200      maxpooling2d_11[0][0]            
____________________________________________________________________________________________________
merge_9 (Merge)                  (None, 7, 7, 1024)    0           convolution2d_48[0][0]           
                                                                   convolution2d_50[0][0]           
                                                                   convolution2d_52[0][0]           
                                                                   convolution2d_53[0][0]           
____________________________________________________________________________________________________
convolution2d_55 (Convolution2D) (None, 7, 7, 192)     196800      merge_9[0][0]                    
____________________________________________________________________________________________________
convolution2d_57 (Convolution2D) (None, 7, 7, 48)      49200       merge_9[0][0]                    
____________________________________________________________________________________________________
maxpooling2d_12 (MaxPooling2D)   (None, 7, 7, 1024)    0           merge_9[0][0]                    
____________________________________________________________________________________________________
convolution2d_54 (Convolution2D) (None, 7, 7, 384)     393600      merge_9[0][0]                    
____________________________________________________________________________________________________
convolution2d_56 (Convolution2D) (None, 7, 7, 384)     663936      convolution2d_55[0][0]           
____________________________________________________________________________________________________
convolution2d_58 (Convolution2D) (None, 7, 7, 128)     153728      convolution2d_57[0][0]           
____________________________________________________________________________________________________
convolution2d_59 (Convolution2D) (None, 7, 7, 128)     131200      maxpooling2d_12[0][0]            
____________________________________________________________________________________________________
merge_10 (Merge)                 (None, 7, 7, 1024)    0           convolution2d_54[0][0]           
                                                                   convolution2d_56[0][0]           
                                                                   convolution2d_58[0][0]           
                                                                   convolution2d_59[0][0]           
____________________________________________________________________________________________________
averagepooling2d_1 (AveragePoolin(None, 1, 1, 1024)    0           merge_10[0][0]                   
____________________________________________________________________________________________________
flatten_1 (Flatten)              (None, 1024)          0           averagepooling2d_1[0][0]         
____________________________________________________________________________________________________
dense_1 (Dense)                  (None, 128)           131200      flatten_1[0][0]                  
____________________________________________________________________________________________________
lambda_1 (Lambda)                (None, 128)           0           dense_1[0][0]                    
====================================================================================================
Total params: 7456944
____________________________________________________________________________________________________
None

4 个答案:

答案 0 :(得分:7)

除了学习率太高之外,可能发生的事情是,有效地使用了不稳定的三元组选择策略。例如,如果你只使用'hard triplets'(距离小于ap距离的三元组),你的网络权重可能会将所有嵌入都折叠到一个点(使得损失始终相等)到边距(你的_alpha),因为所有嵌入距离都是零。)

这也可以通过使用其他类型的三元组来修复(例如'半硬三元组',其中ap小于a,但ap和an之间的距离仍然小于margin) 。所以,如果你总是检查这个......可以在这篇博文中详细解释:https://omoindrot.github.io/triplet-loss

答案 1 :(得分:5)

您是否正在限制嵌入到"在d维超球面上?#34;?尝试在嵌入式CNN出现后立即运行tf.nn.l2_normalize

问题可能在于嵌入有点像智能用户。减少损失的一种简单方法是将所有内容设置为零。 l2_normalize强制它们为单位长度。

看起来你想要在最后一次平均游泳池之后添加规范化。

答案 2 :(得分:2)

我遇到了同样的问题,我做了一些研究工作。我认为这是因为三重态丢失需要多个输入,这可能会导致网络生成这样的输出。我还没有解决问题,但您可以查看keras的问题页面了解更多详情https://github.com/keras-team/keras/issues/9498

在问题页面中,我实现了假数据集和假三元组丢失以重新减少问题,在我改变了网络的输入结构后,损失变得正常

答案 3 :(得分:0)

张量流中的损失函数需要标签列表,即整数列表。我认为您正在传递2D矩阵,即一种热编码。

尝试

import keras.backend as K
from tf.contrib.losses.metric_learning import triplet_semihard_loss

def loss(y_true, y_pred):
    y_true = K.argmax(y_true, axis = -1)
    return triplet_semihard_loss(labels=y_true, embeddings=y_pred, margin=1.)