Keras分类澄清

时间:2020-08-07 05:00:02

标签: python machine-learning keras deep-learning conv-neural-network

我有一个图像数据集和对象的边界框坐标,我从数据集中选择5个对象进行分类。为了进行训练,我使用边界框坐标裁剪了每个图像中的每个对象。然后,进行一次热编码(对象名称-['A','B','C','D','E']),进行训练测试拆分,最后进行训练。我的验证准确性达到了99%。当我使用裁剪后的图像测试模型时,该图像每个图像仅包含一个对象(无背景),该模型可以完美地对对象进行分类。

但是,当我使用包含所有5个对象的图像测试模型时,预测并不准确。显然,该模型返回所有​​对象的预测概率,但是该模型始终仅对第一个对象(对象A)给出高概率。像这样-[[1.00000000e+00 2.64929882e-18 4.15273056e-17 1.11363124e-26 4.15807750e-22]]

我不明白,测试图像包含所有5个对象,我认为该模型对所有对象给出高概率,而不是仅对一个对象(对象A)给出高概率。您能帮我理解这个问题吗?如果给出包含所有5个对象的输入图像,应该怎么做才能使所有对象具有较高的概率。

代码-

标签处理-

le = sklearn.preprocessing.LabelEncoder()
y = le.fit_transform(labels)
y = keras.utils.np_utils.to_categorical(y,5)

变量'labels'是对象名称('A','B','C','D','E')的numpy数组

模型-

  model = Sequential()

  model.add(Conv2D(32, (3, 3), padding='same', input_shape=(224, 224, 3), activation="relu"))
  model.add(MaxPooling2D(pool_size=(2, 2)))

  model.add(Conv2D(filters = 64, kernel_size = (3,3),padding = 'Same',activation ='relu'))
  model.add(MaxPooling2D(pool_size=(2,2)))

  model.add(Conv2D(filters = 96, kernel_size = (3,3),padding = 'Same',activation ='relu'))
  model.add(MaxPooling2D(pool_size=(2,2)))

  model.add(Conv2D(filters = 128, kernel_size = (3,3),padding = 'Same',activation ='relu'))
  model.add(MaxPooling2D(pool_size=(2,2)))

  model.add(Conv2D(filters = 256, kernel_size = (3,3),padding = 'Same',activation ='relu'))
  model.add(MaxPooling2D(pool_size=(2,2)))

  model.add(Flatten())
  model.add(Dense(512, activation="relu"))
  model.add(Dropout(0.2))
  model.add(Dense(256, activation="relu"))
  model.add(Dropout(0.5))
  model.add(Dense(5, activation="softmax"))

  model.compile(
  loss='categorical_crossentropy',
  optimizer= Adam(),
  metrics=['accuracy']
  )

  model.summary()

  History = model.fit(x_train, y_train , epochs=30, verbose = 1,
              validation_data=(x_test, y_test), 
              batch_size = 128, 
              shuffle=True,
              )

metrics.classification_report()的输出-

                    precision    recall  f1-score   support

                0       1.00      1.00      1.00       246
                1       1.00      1.00      1.00       284
                2       1.00      1.00      1.00       266
                3       1.00      1.00      1.00       284
                4       0.99      1.00      1.00       241

          accuracy                           1.00      1804
        macro avg       1.00      1.00      1.00      1804
      weighted avg       1.00      1.00      1.00      1804

0 个答案:

没有答案