在输出张量的每个元素上汇总Keras损失函数

时间:2019-08-12 13:19:54

标签: python tensorflow keras

我有一个custom_loss_function(y_true, y_predicted)。由于外部代数库的限制,仅当y_truey_predicted均为6D向量时,此函数才起作用。我的输出张量具有(batch_count, element_count, 6)的形状,我想在custom_loss_functiony_true的最终轴上逐个应用y_predicted元素,然后使用一些函数将它们聚合。卑鄙的。

我发现了tf.map_fn(...) / keras.backend.map_fn(...),但是它们仅在轴0上运行。我还尝试了在多个批次之间首先进行串联,如下所示:

def mean_custom_loss(true_output, predicted_output):
    batches = (keras.backend.concatenate(true_output, 0)), (keras.backend.concatenate(predicted_output, 0))
    return keras.backend.mean(keras.backend.map_fn(lambda x: custom_loss_function(x[1], x[0]), batches, dtype='float32'))

但是这似乎也不起作用,因为批次是由Keras内部处理的,因此无法正确处理。在那种情况下,keras.backend.concatenate会抛出 TypeError:仅当启用急切执行时,张量对象才可迭代。

我可能很想念我。有没有简单的方法可以完成工作,并在输出的最终轴上逐个元素地应用custom_loss_function

0 个答案:

没有答案