keras中的model.compile参数'weighted_metrics'和model.fit_generator参数'class_weight'之间的区别?

时间:2019-07-03 10:34:42

标签: python keras deep-learning image-recognition

在训练用于图像分类的keras模型(来自DOG BREED IDENTIFICATION数据集,KAGGLE的120个类)时,我需要使用在某处读取的类权重来平衡这些类,在示例中,我已经看到人们使用fit_generator的参数class_weight。但是我在model.compile中找到了另一个参数weighted_metrics,其在文档中的描述是:“在训练和测试期间要通过sample_weight或class_weight评估和加权的指标列表”。我要用这个吗?请通过任何示例说明此参数的用途。

#Calculating Class weights
counter = Counter(train_generator.classes)
max_value = float(max(counter.values()))

CLASS_WEIGHTS = {classid: max_value / num_occurences
                 for classid, num_occurences in counter.items()}
# Model Compile
model.compile(optimizer=Adam(lr=LR),
              loss=categorical_crossentropy,
              metrics=[categorical_accuracy],
              weighted_metrics=None) # <--------------- This parameter

STEPS_PER_EPOCH = train_generator.n//train_generator.batch_size
VAL_STEPS = val_generator.n//val_generator.batch_size

model.fit_generator(train_generator,
                    steps_per_epoch=STEPS_PER_EPOCH,
                    epochs=EPOCHS,
                    callbacks=callback_list,
                    verbose=1,
                    class_weight=CLASS_WEIGHTS,
                    validation_data=val_generator,
                    validation_steps=VAL_STEPS) # USED CLASS_WEIGHTS HERE

1 个答案:

答案 0 :(得分:0)

是的,您可以将它们用于不平衡的数据集。

  

weighted_metrics

是考虑到

的指标列表
  

class_weights

您传入了fit_generator。

因此在您的示例中,您可以设置

weighted_metrics=['accuracy']

class_weight = {0 : 3, 1: 4}

weighted_metrics参数的目的是给出一系列指标,这些指标将考虑在fit_generator中传递的class_weights。