训练后如何使用带GGAS的VGG16预测图像(外部数据集)?

时间:2019-12-14 00:32:53

标签: python tensorflow keras vgg-net

我已经使用自己的数据集使用keras训练了VGG16网络,该数据集有10个类。所以我用10个类修改了激活层。

这是代码

TRAIN_DIR = "D:\\Dataset\\training"   
VALIDATION_DIR = "D:\\Dataset\\validation"

第2部分

   from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
IMAGE_WIDTH = 300
IMAGE_HEIGHT = 300
BATCH_SIZE = 16

第3部分

train_datagen = ImageDataGenerator(rescale=1./255,      
                                    rotation_range=40,      
                                    width_shift_range=0.2,
                                    height_shift_range=0.2,
                                    shear_range=0.2,
                                    zoom_range=0.2,
                                    horizontal_flip=True, 
                                    fill_mode='nearest'
                                  )

validation_datagen = ImageDataGenerator(rescale=1./255, 
                                  )

train_generator = train_datagen.flow_from_directory(TRAIN_DIR, 
                                                    target_size=(IMAGE_WIDTH, IMAGE_HEIGHT), 
                                                    batch_size = BATCH_SIZE, 
                                                    shuffle=True, # By shuffling the images we add some randomness and prevent overfitting
                                                    class_mode="categorical")

validation_generator = validation_datagen.flow_from_directory(VALIDATION_DIR, 
                                                    target_size=(IMAGE_WIDTH, IMAGE_HEIGHT), 
                                                    batch_size = BATCH_SIZE, 
                                                    shuffle=True,
                                                    class_mode="categorical")

第4部分

training_samples = 1097
validation_samples = 272
total_steps = training_samples // BATCH_SIZE

加载VGG16

#VGG16 network with pretrained weights is used

from keras.applications import vgg16
model = vgg16.VGG16(weights='imagenet', include_top=False, input_shape=(IMAGE_WIDTH, IMAGE_HEIGHT, 3), pooling="max")

for layer in model.layers[:-5]:
        layer.trainable = False

for layer in model.layers:
    print(layer, layer.trainable)

第5部分

from keras.layers import Dense, GlobalAveragePooling2D, Dropout
from keras.models import Model, Sequential

# Although this part can be done also with the functional API, I found that for this simple models, this becomes more intuitive
transfer_model = Sequential()
for layer in model.layers:
    transfer_model.add(layer)
transfer_model.add(Dense(512, activation="relu")) 
transfer_model.add(Dropout(0.5))
transfer_model.add(Dense(10, activation="softmax")) 

第6部分

# Adam optimizer and learning rate 0.0001

from keras import optimizers
adam = optimizers.Adam(lr=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.00001)

transfer_model.compile(loss="categorical_crossentropy",
                      optimizer=adam,
                      metrics=["accuracy"])

最终培训

model_history = transfer_model.fit_generator(train_generator, steps_per_epoch=training_samples // BATCH_SIZE,
                                            epochs=25,
                                            validation_data=validation_generator,
                                            validation_steps=validation_samples // BATCH_SIZE)

第7部分,使用互联网上的一些随机图像进行预测

test_path = "D:\\Dataset\\predict\\"
test_datagen = ImageDataGenerator(rescale=1./255,      
                                    rotation_range=40,      
                                    width_shift_range=0.2,
                                    height_shift_range=0.2,
                                    shear_range=0.2,
                                    zoom_range=0.2,
                                    horizontal_flip=True, 
                                    fill_mode='nearest'
                                  )



test_generator = test_datagen.flow_from_directory(test_path, 
                                                    target_size=(IMAGE_WIDTH, IMAGE_HEIGHT), 
                                                    batch_size = 50, 
                                                    class_mode="categorical")
enter code here

在这一部分中,我试图进行预测,但是得到这种数字却没有获得我想要作为图像的实际预测结果

pred = model.predict_generator(test_generator, steps=1)
print(pred)

结果是这样的,但我希望这些是真实的图像,但无法弄清楚如何。

1 个答案:

答案 0 :(得分:0)

您不能从网络中输出图像,也不清楚您如何想象它会如何工作-图像是输入,输出是一个数字列表,每个类一个值。您可以将这些数字解释为图像包含该类对象的概率。

然后,您可以找到最可能的类(例如,使用argmax函数),并根据需要显示该类的图像-但这必须单独完成。

注意-您正在使用原始模型而不是transfer_model运行预测:

pred = model.predict_generator(test_generator, steps=1)

您应该使用训练有素的转移模型来获得班级预测,这将采用包含10个概率的向量的形式,每个班级一个值。