使用tf.GradientTape的预训练模型进行的转移学习无法收敛

时间:2019-07-30 09:41:17

标签: python keras tensorflow2.0 transfer-learning

我想用预训练的喀拉斯模型进行迁移学习

import tensorflow as tf
from tensorflow import keras

base_model = keras.applications.MobileNetV2(input_shape=(96, 96, 3), include_top=False, pooling='avg')
x = base_model.outputs[0]
outputs = layers.Dense(10, activation=tf.nn.softmax)(x)

model = keras.Model(inputs=base_model.inputs, outputs=outputs)

使用keras编译/拟合功能进行培训可以收敛

model.compile(optimizer=keras.optimizers.Adam(), loss=keras.losses.SparseCategoricalCrossentropy(), metrics=['accuracy'])

history = model.fit(train_data, epochs=1)

结果是:损失:0.4402-准确性:0.8548

我想用tf.GradientTape训练,但它不能收敛

optimizer = keras.optimizers.Adam()
train_loss = keras.metrics.Mean()
train_acc = keras.metrics.SparseCategoricalAccuracy()
def train_step(data, labels):    
    with tf.GradientTape() as gt:
        pred = model(data)
        loss = keras.losses.SparseCategoricalCrossentropy()(labels, pred)

    grads = gt.gradient(loss, model.trainable_variables)

    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    train_loss(loss)
    train_acc(labels, pred)

for xs, ys in train_data:
    train_step(xs, ys)

print('train_loss = {:.3f}, train_acc = {:.3f}'.format(train_loss.result(), train_acc.result()))

但是结果是:train_loss = 7.576,train_acc = 0.101

如果我仅通过设置来训练最后一层

base_model.trainable = False

它收敛,结果是:train_loss = 0.525,train_acc = 0.823

代码有什么问题?我应该如何修改?谢谢

2 个答案:

答案 0 :(得分:1)

尝试将RELU用作激活功能。如果您使用除RELU之外的其他激活功能,则可能是消失梯度问题。

答案 1 :(得分:1)

在我的评论之后,它之所以未能收敛,是因为您选择的学习率太大。这导致重量变化太大,损失爆炸。将base_model.trainable设置为False时,网络中的大多数权重都是固定的,学习率非常适合您的最后一层。这是一张照片: enter image description here

通常,每次实验都应选择学习率。

编辑:按照威尔逊的评论,我不确定这是您得出不同结果的原因,但是可能是这样:

当您指定损失时,您的损失是按批次中的每个元素计算的,然后要获得批次的损失,您可以取损失的总和或平均值,具体取决于您选择的损失不同的幅度。例如,如果您的批次大小为64,则对损失进行求和将使您产生64倍的损失,这将产生64倍的梯度,因此,选择批次大小为64的平均数之和就好比选择64倍的学习率。 因此,可能您得到不同结果的原因是,默认情况下,包装在keras.losses中的model.compile具有不同的归约方法。同样,如果通过求和方法减少了损失,则损失的大小取决于批量大小,如果批次大小是两倍,则(平均)损失是两倍,梯度是两倍,所以就像将学习速度提高一倍。

我的建议是检查损失所使用的减少方法,以确保两种情况下的损失方法均相同,如果是总和,则检查批次大小是否相同。我建议一般使用均值缩减,因为它不受批量大小的影响。