Predicting the class of periocular images with a CNN trained in Keras

Time: 2020-01-16 12:31:34

Tags: python keras deep-learning prediction multiclass-classification

Training

        from keras.preprocessing.image import ImageDataGenerator
        from keras.models import Sequential
        from keras.layers import Conv2D, MaxPool2D, Flatten, Dense
        from keras.optimizers import SGD
        from keras.callbacks import TensorBoard

        # Augment the training images and scale pixel values from [0, 255] to [0, 1].
        # Any image later passed to model.predict needs the same 1/255 scaling.
        datagen = ImageDataGenerator(rotation_range=40, width_shift_range=0.2,
                                     height_shift_range=0.2, zoom_range=0.2,
                                     rescale=1./255.)

        model = Sequential()

        # Three convolution + max-pooling blocks on 150x150 RGB input
        model.add(Conv2D(32, (3, 3), input_shape=(150, 150, 3), activation="relu"))
        model.add(MaxPool2D(pool_size=(2, 2)))
        model.add(Conv2D(32, (3, 3), activation="relu"))
        model.add(MaxPool2D(pool_size=(2, 2)))
        model.add(Conv2D(64, (3, 3), activation="relu"))
        model.add(MaxPool2D(pool_size=(2, 2)))

        # Fully connected head; the final softmax layer has one unit per class folder (5)
        model.add(Flatten())
        model.add(Dense(1024, activation="relu"))
        model.add(Dense(512, activation="relu"))
        model.add(Dense(512, activation="relu"))
        model.add(Dense(512, activation="relu"))
        model.add(Dense(512, activation="relu"))
        model.add(Dense(512, activation="relu"))
        model.add(Dense(5, activation="softmax"))

        model.compile(loss="categorical_crossentropy", optimizer=SGD(), metrics=["acc"])

        # Each subfolder of the dataset directory is one class; labels are one-hot
        # encoded because class_mode defaults to "categorical"
        train_gen = datagen.flow_from_directory("/home/vishu/Desktop/basics/dataset",
                                                target_size=(150, 150), batch_size=100)

        tb = TensorBoard(log_dir=".")

        # Pass the TensorBoard callback so it actually logs the run.
        # (In TensorFlow 2.x Keras, model.fit accepts the generator directly.)
        model_history = model.fit_generator(train_gen, epochs=2, callbacks=[tb])

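After training, the model always gives me 4 as the output when I use it to predict.
How should I predict the correct image class?
I load the input images from folders: I created 5 folders for the 5 classes. So how should I predict the class of an image?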
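
A minimal sketch of how `flow_from_directory` turns the five class folders into label indices, assuming its default behaviour when no `classes` argument is passed: every subfolder is one class, and indices follow the folder names in sorted order. `folder_class_indices` and `index_to_class` are hypothetical helpers for illustration; with the real generator, `train_gen.class_indices` already holds this dictionary.

```python
import os


def folder_class_indices(directory):
    """Mimic the default labelling of flow_from_directory: one class per
    subdirectory, indexed by the sorted subdirectory names."""
    classes = sorted(
        name for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))  # plain files are ignored
    )
    return {name: index for index, name in enumerate(classes)}


def index_to_class(class_indices):
    """Invert {class_name: index} so a predicted index maps back to a folder name."""
    return {index: name for name, index in class_indices.items()}
```

Because the folder names are sorted as strings, a predicted index of 4 refers to the last of the five folders in that order; a folder named `10` would sort before one named `2`.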

1 Answer:

Answer 0 (score: 0)

You forgot the rescaling (division by 255) that you perform in the ImageDataGenerator. The same scaling has to be applied to any new test data, so it must also be done in the code that prepares images for prediction.
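
A minimal sketch of that preprocessing, assuming NumPy and an image already loaded as an (H, W, 3) array of raw pixel values in the range 0 to 255 (for example via `keras.preprocessing.image.load_img(path, target_size=(150, 150))` followed by `img_to_array`). Only the `rescale` step is reproduced, because the random rotations, shifts and zooms are training-time augmentation. `prepare_for_predict` and `decode_prediction` are hypothetical helper names.

```python
import numpy as np


def prepare_for_predict(img_array):
    """Apply the same scaling as ImageDataGenerator(rescale=1./255.) and add a
    batch axis, giving shape (1, H, W, 3) as model.predict expects."""
    x = np.asarray(img_array, dtype="float32") * (1.0 / 255.0)
    return np.expand_dims(x, axis=0)


def decode_prediction(probs, class_indices):
    """Turn a (1, num_classes) softmax output into (class_name, probability),
    using a {class_name: index} dict such as train_gen.class_indices."""
    index_to_name = {index: name for name, index in class_indices.items()}
    best = int(np.argmax(probs[0]))
    return index_to_name[best], float(probs[0][best])
```

With the trained model this would be used roughly as `decode_prediction(model.predict(prepare_for_predict(arr)), train_gen.class_indices)`, where `arr` is the loaded image array. Feeding unscaled 0 to 255 values instead gives the network inputs about 255 times larger than anything it saw in training.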