TensorFlow / Keras的fit()函数的class_weight参数如何工作?

时间:2019-09-14 10:31:28

标签: tensorflow keras loss-function

我使用TensorFlow 1.12和Keras进行语义分割。我使用其 PASS src/stackoverflow/57802233/index.spec.ts mapDispatchToProps ✓ t1 (11ms) ----------|----------|----------|----------|----------|-------------------| File | % Stmts | % Branch | % Funcs | % Lines | Uncovered Line #s | ----------|----------|----------|----------|----------|-------------------| All files | 100 | 100 | 100 | 100 | | index.ts | 100 | 100 | 100 | 100 | | ----------|----------|----------|----------|----------|-------------------| Test Suites: 1 passed, 1 total Tests: 1 passed, 1 total Snapshots: 0 total Time: 4.201s 参数为tf.keras.Model.fit()提供了一个权重向量(大小等于类数)。我想知道这在内部如何运作。我使用自定义损失函数(骰子损失和焦点损失等),并且权重无法与预测值或唯一热点事实相乘,然后再输入损失函数,因为这不会产生任何影响感。我的损失函数输出一个标量值,因此也不能与函数输出相乘。那么在何处以及如何准确地考虑到班级的权重?

我的自定义损失函数是:

class_weight

2 个答案:

答案 0 :(得分:1)

您可以在github中的keras源代码中引用以下代码:

    class_sample_weight = np.asarray(
        [class_weight[cls] for cls in y_classes if cls in class_weight])

    if len(class_sample_weight) != len(y_classes):
      # subtract the sets to pick all missing classes
      existing_classes = set(y_classes)
      existing_class_weight = set(class_weight.keys())
      raise ValueError(
          '`class_weight` must contain all classes in the data.'
          ' The classes %s exist in the data but not in '
          '`class_weight`.' % (existing_classes - existing_class_weight))

  if class_sample_weight is not None and sample_weight is not None:
    # Multiply weights if both are provided.
    return class_sample_weight * sample_weight

如您所见,首先将class_weight转换为一个numpy数组class_sample_weight,然后将其与sample_weight相乘。

来源:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/engine/training_utils.py

答案 1 :(得分:0)

Keras Official Docs中所述,

  

class_weight:可选的字典映射类索引(整数)   到权重(浮动)值,用于加权损失函数   (仅在培训期间)。这对于告诉模型“支付   更多关注”来自代表性不足的班级的样本。

基本上,我们提供类别权重,其中我们具有类别错误 e。意思是,训练样本并非在所有课程之间均匀分布。有些类别的样本较少,而有些类别的样本较多。

我们需要分类器对数量较少的类给予更多的关注。一种方法可能是增加低样本类的损失值。更高的损失意味着更高的优化,从而实现有效的分类。

就Keras而言,我们传递了一个dict映射类索引到其权重(损失值将乘以的因子)。让我们举个例子吧

class_weights = { 0 : 1.2 , 1 : 0.9 }

内部,类别0和1的损失值将乘以它们相应的权重值。

weighed_loss_class0 = loss0 * class_weights[0]
weighed_loss_class1 = loss1 * class_weights[1]

现在,将使用the weighed_loss_class0weighed_loss_class1进行反向传播。

请参见thisthis