计算喀拉斯邦的哈希距离损失

时间:2018-11-07 14:41:43

标签: python tensorflow keras deep-learning loss-function

我是keras的新手,而tensorflow,我有一个项目使用keras来实现我自己的损失函数。我想像在ranking-hashing中一样应用目标函数,并以三种方式实现。我想基于三层输出来计算损失。例如a = [0,1,1,1,0],b = [0,1,1,0,0],则a和b之间的相似度为2/3,因为sum(a AND b)= 2,并且sum(A OR B)= 3

  1. 每个输出转换为哈希值{0,1} ^ N,其中N是输出的长度。
  2. 对于每对散列,我根据汉明距离计算它们之间的距离,然后取平均值

我的问题是,在第一个时期,我会得到很小的损失结果,我认为这是异常的,并且我确定我的代码中有问题,但是现在我不知道如何解决它。

这是我的损失函数代码

def hash_lambda_func(args):
   y_true, y_pred, ingrs_rep, instrs_rep, imgs_rep = args
   def my_hash(x): # transforming layer's output into hash {0,1}
      return K.cast(K.round(x), dtype='float32')
   i = my_hash(ingrs_rep)
   j = my_hash(instrs_rep)
   k = my_hash(imgs_rep)

   #tensors = tf.map_fn(lambda i: i, x)

   ab = K.any([i,j], axis=0)
   ac = K.any([i,k], axis=0)
   bc = K.any([j,k], axis=0)
   ab = tf.reduce_sum(tf.cast(ab, tf.float32))
   ac = tf.reduce_sum(tf.cast(ac, tf.float32))
   bc = tf.reduce_sum(tf.cast(bc, tf.float32))

   ab1 = K.all([i,j], axis=0)
   ac1 = K.all([i,k], axis=0)
   bc1 = K.all([j,k], axis=0)
   ab1 = tf.reduce_sum(tf.cast(ab1, tf.float32))
   ac1 = tf.reduce_sum(tf.cast(ac1, tf.float32))
   bc1 = tf.reduce_sum(tf.cast(bc1, tf.float32))

   a = (tf.reduce_mean([ab1/ab,ac1/ac,bc1/bc])*(1-(y_true-y_pred)))**2
   return K.mean(a)

用于模型中的工具损失

y_true_input = Input(shape=(num_class,))
dense_y = Dense(20)(y_true_input)
merged = concatenate([glove_dense1, instrs_dense1, img2_dense1])
densex = Dense(20)(merged)
y_pred = Activation('softmax')(densex)
my_loss = Lambda(
    hash_lambda_func, output_shape=(1,),
    name='lambda_hash')([y_pred, y_true_input, glove_dense1, instrs_dense1, img2_dense1])
model_1 = Model(inputs=[glove_input, instrs_input, img2_input1, y_true_input],
            outputs=my_loss)

编译模型

model_1.compile(loss={'lambda_hash': lambda y_true, y_pred: y_pred}, optimizer='rmsprop', metrics=['accuracy', mean_pred])

这些是错误的代码,很抱歉。感谢您对此的任何帮助或建议。

训练历史,看来我的损失功能不符合预期

Epoch 1/10
1521/2343 [==================>...........] - ETA: 23:36 - loss: 0.0141 - mean_pred: 0.9512

0 个答案:

没有答案