Keras多标签分类

时间:2019-10-15 12:54:19

标签: tensorflow keras multiclass-classification

我刚刚遵循了有关Keras https://www.pyimagesearch.com/2018/05/07/multi-label-classification-with-keras/的多标签分类的教程。

模型结构:

model = Sequential()
inputShape = (height, width, depth)
chanDim = -1

model.add(Conv2D(32, (3, 3), padding="same", input_shape=inputShape))
model.add(Activation("relu"))
model.add(BatchNormalization(axis=chanDim))
model.add(MaxPooling2D(pool_size=(3, 3)))
model.add(Dropout(0.25))

model.add(Conv2D(64, (3, 3), padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(axis=chanDim))
model.add(Conv2D(64, (3, 3), padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(axis=chanDim))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Conv2D(128, (3, 3), padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(axis=chanDim))
model.add(Conv2D(128, (3, 3), padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(axis=chanDim))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Flatten())
model.add(Dense(1024))
model.add(Activation("relu"))
model.add(BatchNormalization())
model.add(Dropout(0.5))

model.add(Dense(classes))
model.add(Activation(finalAct))

现在,我将使用此方法进行香蕉成熟度等级分类。它具有四个类(未成熟香蕉,成熟香蕉,未成熟香蕉和not_banana)。我为每个班级训练了600张数据集图像。训练过程以30个纪元完成,准确度约为94%。但是,当我使用分类过程进行测试时,我得到了一些误报。请问您对此案有何建议/建议?

Output

GIF的翻译:

Bukan pisang -> not_banana      
Pisang mentah -> unripe banana     
Pisang matang -> ripe banana      
Pisang terlalu matang -> overripe banana

谢谢您的时间。

0 个答案:

没有答案