在Tensorflow 2.0中实现自定义损失功能

时间:2019-08-08 13:45:22

标签: tensorflow keras loss-function tensorflow2.0

我正在建立时间序列分类模型。数据非常不平衡,所以我决定使用加权交叉熵函数作为损失。

Tensorflow提供了tf.nn.weighted_cross_entropy_with_logits,但我不确定如何在TF 2.0中使用它。因为我的模型是使用tf.keras API构建的,所以我正在考虑创建这样的自定义损失函数:

pos_weight=10
def weighted_cross_entropy_with_logits(y_true,y_pred):
  return tf.nn.weighted_cross_entropy_with_logits(y_true,y_pred,pos_weight)

# .....
model.compile(loss=weighted_cross_entropy_with_logits,optimizer="adam",metrics=["acc"])

我的问题是:有没有一种方法可以直接将tf.nn.weighted_cross_entropy_with_logits与tf.keras API一起使用?

1 个答案:

答案 0 :(得分:2)

您可以将类权重直接传递给model.fit函数。

  

class_weight:可选的将类索引(整数)映射到   权重(浮动)值,用于对损失函数(在   仅培训)。告诉模型“多付钱”可能很有用   注意”来自代表性不足的班级的样本。

例如:

{
    0: 0.31, 
    1: 0.33, 
    2: 0.36, 
    3: 0.42, 
    4: 0.48
}

Source