Simple tf.keras ResNet50 model fails to converge

Posted: 2020-06-29 22:41:07

Tags: python tensorflow keras deep-learning neural-network

I am using the ResNet50V2 model from keras.applications for image classification, but I keep running into problems getting the model to converge to any meaningful accuracy. I previously built the same model on the same data in MATLAB and reached about 75% accuracy, but here training hovers around 30% accuracy and the loss does not decrease. I suspect there is a very simple mistake somewhere, but I can't find it.

import tensorflow as tf

train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1./224,
    validation_split=0.2)

train_generator = train_datagen.flow_from_directory(main_dir,
                                                    class_mode='categorical',
                                                    batch_size=32,
                                                    target_size=(224,224),
                                                    shuffle=True,
                                                    subset='training')

validation_generator = train_datagen.flow_from_directory(main_dir,
                                                        target_size=(224, 224),
                                                        batch_size=32,
                                                        class_mode='categorical',
                                                        shuffle=True,
                                                        subset='validation')

IMG_SHAPE = (224, 224, 3)

base_model = tf.keras.applications.ResNet50V2(
    input_shape=IMG_SHAPE,
    include_top=False,
    weights='imagenet')

maxpool_layer = tf.keras.layers.GlobalMaxPooling2D()
prediction_layer = tf.keras.layers.Dense(4, activation='softmax')

model = tf.keras.Sequential([
    base_model,
    maxpool_layer,
    prediction_layer
])

opt = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(optimizer=opt,
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
    train_generator,
    steps_per_epoch = train_generator.samples // 32,
    validation_data = validation_generator,
    validation_steps = validation_generator.samples // 32,
    epochs = 20)

1 Answer:

Answer 0 (score: 0)

Since your last layer has a softmax activation, your loss should not use from_logits=True. You only need from_logits=True when the final layer has no activation and outputs raw logits. This is because categorical_crossentropy treats probability outputs differently from logits: with from_logits=True it applies a softmax internally, so a softmax layer followed by from_logits=True ends up applying softmax twice, which flattens the predictions and distorts the loss. Either set from_logits=False (the default) and keep the softmax on the Dense layer, or remove the activation from the Dense layer and keep from_logits=True.
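The effect of the double softmax can be illustrated without Keras. Below is a minimal sketch (plain Python, not the asker's model) that computes categorical cross-entropy on a probability vector directly, and then on the same vector after re-applying softmax, which is effectively what from_logits=True does to an already-softmaxed output:

```python
import math

def softmax(xs):
    # numerically stable softmax over a list of raw scores
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def cross_entropy(target, pred):
    # categorical cross-entropy between a one-hot target and predicted probabilities
    return -sum(t * math.log(p) for t, p in zip(target, pred))

logits = [2.0, 1.0, 0.1, -1.0]   # hypothetical raw outputs of a Dense layer
target = [1.0, 0.0, 0.0, 0.0]    # one-hot label for a 4-class problem

probs = softmax(logits)                              # what a softmax Dense layer emits
loss_ok = cross_entropy(target, probs)               # from_logits=False on probabilities: correct
loss_double = cross_entropy(target, softmax(probs))  # from_logits=True on probabilities: softmax applied twice

print(f"correct loss: {loss_ok:.4f}")
print(f"double-softmax loss: {loss_double:.4f}")
```

The double-softmax loss is larger and much less sensitive to the model's predictions, because re-applying softmax squashes the probabilities toward uniform. That matches the symptom in the question: the loss barely moves and accuracy stalls.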