I have a network that uses a custom loss function to mask out inputs that share the same label within a beam, but the training accuracy/loss never improves past about 1%. I'm not sure what is causing this; I've checked the input data and found no problems.
# Custom loss to take a full batch (of size beam) and apply a mask to calculate the true loss within the beam
from itertools import permutations
import tensorflow as tf

beam_size = 10

def create_mask(y, yhat):
    # All ordered pairs of positions within the beam
    idxs = list(permutations(range(beam_size), r=2))
    perms_y = tf.squeeze(tf.gather(y, idxs))
    perms_yhat = tf.squeeze(tf.gather(yhat, idxs))
    # Positions of the pairs whose two labels differ
    mask = tf.where(tf.not_equal(perms_y[:, 0], perms_y[:, 1]))
    mask = tf.reduce_sum(mask, 1)
    # Select the label pairs and prediction pairs according to the mask
    uneq = tf.boolean_mask(perms_y, mask, axis=0)
    yhat_uneq = tf.boolean_mask(perms_yhat, mask, axis=0)
    return uneq, yhat_uneq

def mask_acc(y, yhat):
    uneq, yhat_uneq = create_mask(y, yhat)
    uneq = tf.argmax(uneq, 1)
    yhat_uneq = tf.argmax(yhat_uneq, 1)
    # argmax and compare
    return tf.cond(tf.greater(tf.size(yhat_uneq), 1),
                   lambda: tf.reduce_sum(tf.cast(tf.equal(uneq, yhat_uneq), tf.float32)),
                   lambda: 0.)

def mask_loss(y, yhat):
    # Consider weighted loss
    uneq, yhat_uneq = create_mask(y, yhat)
    #uneq = tf.argmax(uneq, 1)
    # Create all permutations and zero out matches with the mask
    total_loss = tf.reduce_mean(tf.losses.softmax_cross_entropy(onehot_labels=tf.cast(uneq, tf.int32),
                                                                logits=yhat_uneq))
    #d = tf.Print(yhat_uneq, [yhat_uneq], summarize=-1)
    return total_loss
from keras.layers import Input, Dense, LSTM, Lambda, concatenate
from keras.models import Model
from keras import optimizers
from keras import backend as K

x = Input(shape=(72, 300))
aux_input = Input(shape=(72, 3))
probs = Input(shape=(1,))
#dim_red_1 = Dense(100)(x)
dim_red_2 = Dense(25, activation='tanh')(x)
cat = concatenate([dim_red_2, aux_input])
encoded = LSTM(5)(cat)
output = Lambda(lambda x: K.sum(x, axis=1))(encoded)
#cat2 = concatenate([encoded, probs])
#output = Dense(1, activation='linear')(cat2)
sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=.8, nesterov=True)
lstm_model = Model(inputs=[x, aux_input, probs], outputs=output)
lstm_model.compile(optimizer=sgd, loss=mask_loss, metrics=[mask_acc])
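For context, mask_loss and mask_acc treat each batch as a single beam, so I train with batch_size equal to beam_size and keep the rows of a beam together. A minimal sketch of the fit call with dummy data (the array shapes, sample counts, and epochs below are made up for illustration, not my real dataset):

import numpy as np

# Dummy data sized so that every batch of 10 rows is exactly one beam.
# These shapes and counts are illustrative only.
n_beams = 100
X      = np.random.rand(n_beams * beam_size, 72, 300).astype('float32')
X_aux  = np.random.rand(n_beams * beam_size, 72, 3).astype('float32')
X_prob = np.random.rand(n_beams * beam_size, 1).astype('float32')
y      = np.random.randint(0, 2, size=(n_beams * beam_size,)).astype('float32')

# batch_size must equal beam_size, and shuffling is disabled so the 10
# candidates of each beam always arrive in the same batch.
lstm_model.fit([X, X_aux, X_prob], y,
               batch_size=beam_size,
               shuffle=False,
               epochs=5)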
The strange thing is that setting the output activation to softmax improves the accuracy dramatically, yet tf.losses.softmax_cross_entropy expects unnormalized logits, so I'm not sure why that would help.
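To illustrate what I mean, here is a toy snippet, separate from the model above and with made-up numbers, comparing the loss when raw scores are passed as logits versus when already-softmaxed values are passed in; the second softmax applied inside tf.losses.softmax_cross_entropy flattens the distribution:

# Toy comparison, independent of the model above: feeding softmax outputs into
# tf.losses.softmax_cross_entropy applies softmax a second time.
logits = tf.constant([[2.0, -1.0]])      # raw scores for one pair
probs_out = tf.nn.softmax(logits)        # what a softmax output layer would emit
labels = tf.constant([[1.0, 0.0]])

loss_from_logits = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
loss_from_probs = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=probs_out)

with tf.Session() as sess:
    print(sess.run([loss_from_logits, loss_from_probs]))
    # ~0.05 vs ~0.34: the doubly-softmaxed version is squashed toward uniform,
    # since the second softmax only sees values in [0, 1], so its gradients are much weaker.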