Tensorflow预训练的CNN:预测图像的相同类别

时间:2020-10-16 00:50:14

标签: python tensorflow deep-learning cnn

我正在使用MobilenetV2预先训练的模型对具有四类(新鲜植物,患病植物,新鲜叶子,患病叶子)的植物图像进行分类。

我尝试使用以下代码集进行目录方法的扩充和流处理,确保已为val_generator禁用shuffle = False

    train_datagen = ImageDataGenerator(
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)

#Rescaling or normalising the pixels in the image
test_datagen = ImageDataGenerator(
        rescale=1./255)

val_datagen = ImageDataGenerator(
        rescale=1./255)

#Using the flow from directory method to read the images directly from the google drive
train_generator = train_datagen.flow_from_directory(
    directory='/content/drive/My Drive/data/train',
    target_size=(224, 224),
    batch_size=32,
    class_mode="categorical"
)

val_generator = val_datagen.flow_from_directory(
    directory='/content/drive/My Drive/data/val',
    target_size=(224, 224),
    batch_size=32,
    class_mode="categorical",shuffle = False
)
test_generator = test_datagen.flow_from_directory(
    directory='/content/drive/My Drive/data/test',
    target_size=(224, 224),
    batch_size=32,
    class_mode="categorical", shuffle = False
)

当我尝试根据测试数据为随机图像预测类别时,就会出现问题,

img = image.load_img('/content/drive/My Drive/data/test/diseased cotton leaf/dis_leaf (248).jpg',target_size=(224,224))
img = image.img_to_array(img)/255

img.shape #(224, 224, 3)
x=np.expand_dims(img,axis=0)
img_data=preprocess_input(x)
img_data.shape#(1, 224, 224, 3)
pred = model1.predict_generator(img_data)
pred #array([[0.28619877, 0.00713898, 0.3015607 , 0.40510157]], dtype=float32)
np.argmax(pred,axis=1)

针对不同的图像路径返回相同的类。我该如何解决这个问题?

2 个答案:

答案 0 :(得分:1)

您的预处理程序可能会执行/ 255(很有可能)。尝试将其删除或提供未经修改的图像。

答案 1 :(得分:1)

您需要在火车,验证和测试生成器中使用MobileNet预处理图像功能。此功能将像素值缩放到-1和+1之间。删除生成器中的rescale参数。参见下面的代码,了解火车发电机

train_gen=ImageDataGenerator(preprocessing_function=keras.applications.mobilenet.preprocess_input, etc.......

如果输入图像进行预测,请确保使用该功能以相同的方式进行预处理。如果您不想使用mobilenet预处理功能,则可以编写自己的函数来执行相同的操作。参见下面的代码

def pre_processor (img):
    img=img/127.5-1
    return img
train_gen=ImageDataGenerator(preprocessing_function=pre_processor etc......

对您在进行预测之前阅读的图像使用相同的功能