Keras分类模型返回0%的训练准确度

时间:2020-06-24 18:42:42

标签: python tensorflow keras classification vgg-net

一些细节:

任务:使用VGG16对属于10个不相交类的图像进行分类。

问题:我的训练准确性一直都为0%。

代码

主文件:

##################################################################################################################################################
# PREPARE THE DATASET
##################################################################################################################################################
# create a data generator
datagen = tf.keras.preprocessing.image.ImageDataGenerator()

# load and iterate train dataset
train_it = datagen.flow_from_directory(DATASET + "/train", class_mode="categorical", batch_size=BATCH_SIZE, target_size=(IMAGE_SIZE, IMAGE_SIZE))

# load and iterate test dataset
test_it = datagen.flow_from_directory(DATASET + "/test", class_mode="categorical", batch_size=BATCH_SIZE, target_size=(IMAGE_SIZE, IMAGE_SIZE))

#####################################################################################################################################################################################
# PREPARE THE MODEL 
#####################################################################################################################################################################################
model = get_model(MODEL,WEIGHTS)

opt = tf.keras.optimizers.SGD(learning_rate=LR, name="SGD")

# compile the model
model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=[tf.keras.metrics.Accuracy(), tf.keras.metrics.AUC(), tf.keras.metrics.Precision(), tf.keras.metrics.Recall()])

#######################################################################
# TRAIN THE MODEL
#######################################################################
# fit the model
model.fit(train_it, epochs=EPOCHS, batch_size=BATCH_SIZE, shuffle=True)

模型文件:

def vgg16(weights):

    model = keras.applications.VGG16(
            include_top=True,
            weights=weights,
            classes=10
        )

    return(model)

根据我所做的研究,人们似乎在使用分类模型进行线性回归时遇到了这个问题。我看不出如何将我的代码与回归任务混在一起。

欢迎任何帮助!预先感谢。

0 个答案:

没有答案