我正在尝试在Keras中进行转学。我将ResNet50网络设置为无法通过一些额外的层进行训练:
# Image input
model = Sequential()
model.add(ResNet50(include_top=False, pooling='avg')) # output is 2048
model.add(Dropout(0.05))
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.15))
model.add(Dense(512, activation='relu'))
model.add(Dense(7, activation='softmax'))
model.layers[0].trainable = False
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()
然后,我使用ResNet50 x_batch
函数创建输入数据:preprocess_input
以及一个热编码标签y_batch
并进行拟合:
model.fit(x_batch,
y_batch,
epochs=nb_epochs,
batch_size=64,
shuffle=True,
validation_split=0.2,
callbacks=[lrate])
十个左右的训练时间后,训练准确性接近100%,但验证准确性实际上从50%下降到30%,而验证损失却在稳步增加。
但是,如果我改为只创建最后一层的网络:
# Vector input
model2 = Sequential()
model2.add(Dropout(0.05, input_shape=(2048,)))
model2.add(Dense(512, activation='relu'))
model2.add(Dropout(0.15))
model2.add(Dense(512, activation='relu'))
model2.add(Dense(7, activation='softmax'))
model2.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model2.summary()
并输入ResNet50预测的输出:
resnet = ResNet50(include_top=False, pooling='avg')
x_batch = resnet.predict(x_batch)
然后,验证准确率将达到85%左右...怎么回事?为什么图像输入法不起作用?
更新:
这个问题确实很奇怪。如果我将ResNet50更改为VGG19,它似乎可以正常工作。
答案 0 :(得分:1)
经过大量的搜索后,我发现问题出在ResNet中的Batch Normalization层。 VGGNet中没有批处理规范化层,因此它适用于该拓扑。
在Keras here中有一个修复请求来解决此问题,该请求有更详细的解释:
假设我们使用Keras的一种预训练的CNN,并且希望对其进行微调。不幸的是,我们无法保证BN层中新数据集的均值和方差将与原始数据集的均值和方差相似。结果,如果我们微调顶层,则它们的权重将被调整为新数据集的均值/方差。但是,在推理过程中,顶层将接收使用原始数据集的均值/方差缩放的数据。这种差异会导致准确性降低。
这意味着BN层正在调整训练数据,但是在执行验证时,将使用BN层的原始参数。据我所知,解决方法是允许冻结的BN层使用训练后更新的均值和方差。
一种解决方法是预先计算ResNet输出。实际上,这将大大减少训练时间,因为我们没有重复计算的那一部分。
答案 1 :(得分:0)
您可以尝试:
Res = keras.applications.resnet.ResNet50(include_top=False,
weights='imagenet', input_shape=(IMG_SIZE , IMG_SIZE , 3 ) )
# Freeze the layers except the last 4 layers
for layer in vgg_conv.layers :
layer.trainable = False
# Check the trainable status of the individual layers
for layer in vgg_conv.layers:
print(layer, layer.trainable)
# Vector input
model2 = Sequential()
model2.add(Res)
model2.add(Flatten())
model2.add(Dropout(0.05 ))
model2.add(Dense(512, activation='relu'))
model2.add(Dropout(0.15))
model2.add(Dense(512, activation='relu'))
model2.add(Dense(7, activation='softmax'))
model2.compile(optimizer='adam', loss='categorical_crossentropy', metrics =(['accuracy'])
model2.summary()