图像分类 CNN 模型总是预测相同的值

时间:2021-06-21 19:09:53

标签: tensorflow machine-learning keras deep-learning image-classification

我有一个图像数据集,其结构如下:

money_photo/
           100/
           50/
           10/
           1/

每个目录内有 240 张照片,对应钞票的价值(100、50、10 和 1)。

我正在使用 keras.preprocessing.image_dataset_from_directory 拆分 train 和 val 数据集,如下几行:

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

找到属于 4 个类的 960 个文件。 使用 768 个文件进行训练。

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="validation",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

找到属于 4 个类的 960 个文件。 使用 192 个文件进行验证。

每张图像都被调用到 180x180 像素并对其进行标准化(0..255 像素值在 0<=value<=1 之间具有相应的值)

模型定义如下:

num_classes = 4

model = tf.keras.Sequential([
  layers.experimental.preprocessing.Rescaling(1./255),
  layers.Conv2D(32, 3, activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(32, 3, activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(32, 3, activation='relu'),
  layers.MaxPooling2D(),
  layers.Flatten(),
  layers.Dense(128, activation='relu'),
  layers.Dense(num_classes),
  layers.Activation('softmax')
])

训练后我得到以下结果:

时代 3/3 24/24 [==============================] - 10 秒 425 毫秒/步 - 损失:0.3214 - 准确度:0.8866 - val_loss : 0.2449 - val_accuracy: 0.9115

我使用模型进行预测的方式:

import tensorflow as tf
from PIL import Image
import numpy as np
from skimage import transform

def load(filename):
    np_image = Image.open(filename)
    np_image = np.array(np_image).astype('float32')/255
    np_image = transform.resize(np_image, (180, 180, 3))
    np_image = np.expand_dims(np_image, axis=0)
    return np_image

image = load('abd.jpg')
prediction = model.predict(image)

print(class_names[np.argmax(prediction)])

为什么我总是得到相同的预测值?

1 个答案:

答案 0 :(得分:0)

您的模型内置了重新缩放层,因此您不应重新缩放输入图像。改变一下

np_image = np.array(np_image).astype('float32')/255

np_image = np.array(np_image).astype('float32')