我遇到了5个标签的多标签分类问题(例如[1 0 1 1 0]
)。因此,我希望我的模型在固定召回率,精确召回AUC或ROC AUC等指标上有所改进。
使用与我要优化的性能指标没有直接关系的损耗函数(例如binary_crossentropy
)没有意义。因此,我想使用TensorFlow的global_objectives.recall_at_precision_loss()
或类似函数作为损失函数。
我不是要实现tf.metrics
。我已经成功完成了以下任务:https://stackoverflow.com/a/50566908/3399066
我认为我的问题可以分为两个问题:
global_objectives.recall_at_precision_loss()
或类似名称? the global objectives GitHub page (same as above)上有一个名为loss_layers_example.py
的文件。但是,由于我在TF方面没有太多经验,所以我不太了解如何使用它。另外,Google搜索TensorFlow recall_at_precision_loss example
或TensorFlow Global objectives example
不会给我任何更清晰的示例。
在一个简单的TF示例中如何使用global_objectives.recall_at_precision_loss()
?
(在Keras中):model.compile(loss = ??.recall_at_precision_loss, ...)
就足够了吗?
我的感觉告诉我,由于loss_layers_example.py
中使用了全局变量,因此比这更复杂。
如何在Keras中使用类似于global_objectives.recall_at_precision_loss()
的损失函数?
答案 0 :(得分:1)
与Martino的答案类似,但将从输入中推断形状(将其设置为固定的批量大小对我不起作用)。
外部函数不是严格必需的,但是在配置损失函数时传递参数会更加自然,尤其是当包装器在外部模块中定义时。
import keras.backend as K
from global_objectives.loss_layers import precision_at_recall_loss
def get_precision_at_recall_loss(target_recall):
def precision_at_recall_loss_wrapper(y_true, y_pred):
y_true = K.reshape(y_true, (-1, 1))
y_pred = K.reshape(y_pred, (-1, 1))
return precision_at_recall_loss(y_true, y_pred, target_recall)[0]
return precision_at_recall_loss_wrapper
然后,在编译模型时:
TARGET_RECALL = 0.9
model.compile(optimizer='adam', loss=get_precision_at_recall_loss(TARGET_RECALL))
答案 1 :(得分:0)
我设法通过以下方式使其起作用
)