keras - 获得每个班级的概率

时间:2018-01-11 22:54:34

标签: python deep-learning keras

我试图从keras模型中获得每个类的概率。请在下面找到样本keras模型:

train_datagen = ImageDataGenerator(
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest')

# this is the augmentation configuration we will use for testing:
# only rescaling
test_datagen = ImageDataGenerator(
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest')


train_generator = train_datagen.flow_from_directory(
        '.\\train',  # this is the target directory
        target_size=(width, height),  # all images will be resized to 150x150
        batch_size=batch_size,
        class_mode='binary',
        shuffle=True)  # since we use binary_crossentropy loss, we need binary labels

# this is a similar generator, for validation data
validation_generator = test_datagen.flow_from_directory(
        '.\\validate',
        target_size=(width, height),
        batch_size=batch_size,
        class_mode='binary',
        shuffle=True)

model.fit_generator(
        train_generator,
        steps_per_epoch=4000,
        epochs=2,
        validation_data=validation_generator,
        validation_steps=1600)

然而,在训练模型后,我通过以下方式加载要预测的图像:

{{1}}

我仍然得到类标签,而不是概率。我有什么提示我做错了吗?

修改 这就是模型的训练方式:

{{1}}

1 个答案:

答案 0 :(得分:1)

问题是您在'sparse_categorical_crossentropy' class_mode='binary'使用了ImageDataGenerator 'categorical_crossentropy'。{/ p>

这里有两种可能性:

  1. 将损失更改为class_mode='categorical'并设置class_mode='sparse'
  2. 保留原来的损失,但设置[0, 1, 0, 0]
  3. 要么工作。

    有关两种损失之间的差异,请参阅this answer(在Tensorflow中,但也适用于Keras)。简短版本是稀疏损失期望标签是整数类(例如1,2,3 ......),而普通版本需要单热编码矢量(例如ImageDataGenerator)。

    干杯

    编辑:正如@Simeon Kredatus指出的那样,这是一个正常化问题。 通过在samplewise_center=True构造函数中为训练集和测试集设置适当的标志,即samplewise_std_normalization=Truefunction myTest($x = 5) { echo "<p>Variable x inside function is: $x</p>"; } myTest(); //default x = 5 myTest(6) // will print x = 6 ,可以轻松解决这个问题。 更新答案,以便人们可以看到解决方案。一般来说,请记住垃圾桶中的垃圾原则。