将TensorFlow损失全局目标(recall_at_precision_loss)与Keras(而非指标)结合使用

时间:2019-01-21 08:57:00

标签: tensorflow keras loss-function auc precision-recall

背景

我遇到了5个标签的多标签分类问题(例如[1 0 1 1 0])。因此,我希望我的模型在固定召回率,精确召回AUC或ROC AUC等指标上有所改进。

使用与我要优化的性能指标没有直接关系的损耗函数(例如binary_crossentropy)没有意义。因此,我想使用TensorFlow的global_objectives.recall_at_precision_loss()或类似函数作为损失函数。

非公制

我不是要实现tf.metrics。我已经成功完成了以下任务:https://stackoverflow.com/a/50566908/3399066

问题

我认为我的问题可以分为两个问题:

  1. 如何使用global_objectives.recall_at_precision_loss()或类似名称?
  2. 如何在带有TF后端的Keras模型中使用它?

问题1

the global objectives GitHub page (same as above)上有一个名为loss_layers_example.py的文件。但是,由于我在TF方面没有太多经验,所以我不太了解如何使用它。另外,Google搜索TensorFlow recall_at_precision_loss exampleTensorFlow Global objectives example不会给我任何更清晰的示例。

在一个简单的TF示例中如何使用global_objectives.recall_at_precision_loss()

问题2

(在Keras中):model.compile(loss = ??.recall_at_precision_loss, ...)就足够了吗? 我的感觉告诉我,由于loss_layers_example.py中使用了全局变量,因此比这更复杂。

如何在Keras中使用类似于global_objectives.recall_at_precision_loss()的损失函数?

2 个答案:

答案 0 :(得分:1)

与Martino的答案类似,但将从输入中推断形状(将其设置为固定的批量大小对我不起作用)。

外部函数不是严格必需的,但是在配置损失函数时传递参数会更加自然,尤其是当包装器在外部模块中定义时。

import keras.backend as K
from global_objectives.loss_layers import precision_at_recall_loss

def get_precision_at_recall_loss(target_recall): 
    def precision_at_recall_loss_wrapper(y_true, y_pred):
        y_true = K.reshape(y_true, (-1, 1)) 
        y_pred = K.reshape(y_pred, (-1, 1))   
        return precision_at_recall_loss(y_true, y_pred, target_recall)[0]
    return precision_at_recall_loss_wrapper

然后,在编译模型时:

TARGET_RECALL = 0.9
model.compile(optimizer='adam', loss=get_precision_at_recall_loss(TARGET_RECALL))

答案 1 :(得分:0)

我设法通过以下方式使其起作用

  • 将张量显式调整为BATCH_SIZE长度(请参见下面的代码)
  • 将数据集大小切割为BATCH_SIZE的倍数
)