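I got an accuracy of 0.69 using transfer learning from a pre-trained model (VGG16). I wanted to write the same code with ResNet50, but to my surprise the accuracy was much lower (0.1). I have tried many versions of the code (using the API with ResNet, training my model and then loading the weights... to combine ResNet50 with my model), but the accuracy always stays very low. Below is my VGG16 code, followed by one version of my ResNet50 code. Could you help me fix the second one?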
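For scale, it helps to compare these numbers with chance level. Both models below end in a `Dense(12, activation='softmax')` layer, so there are 12 classes. Assuming the test classes are roughly balanced (the class distribution is not given), uniform random guessing would score about 1/12:

```python
# Number of output units in the final Dense layer of both models
num_classes = 12

# Expected accuracy of uniform random guessing, assuming balanced classes
chance_accuracy = 1 / num_classes
print(f"chance-level accuracy: {chance_accuracy:.3f}")  # prints "chance-level accuracy: 0.083"
```

Under that assumption, the ResNet50 result of 0.1 is barely above chance, while the VGG16 result of 0.69 is well above it.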
**Transfer learning (VGG16): accuracy 0.69**
```
import tensorflow as tf

# Import the VGG16 model pre-trained on the ImageNet dataset
imagenet_model2 = tf.keras.applications.VGG16(weights="imagenet",
                                              include_top=False,
                                              input_shape=(150, 150, 3),
                                              pooling='max')
# Layers of the VGG16 model
imagenet_model2.layers
# Freeze every pre-trained layer
for l in imagenet_model2.layers:
    l.trainable = False
# Build the classification head
model4 = [
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(12, activation='softmax')
]
# Join the two models
model_using_pre_trained_one2 = tf.keras.Sequential(imagenet_model2.layers + model4)
# Choose the training algorithm
model_using_pre_trained_one2.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# Train the model
model_using_pre_trained_one2.fit(trainx2, trainy2, epochs=7, steps_per_epoch=len(trainx2))
# Evaluate the model
test_loss4, test_acc4 = model_using_pre_trained_one2.evaluate(testx2, testy2, verbose=2)
print('\nTest accuracy:', test_acc4)
```
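Both scripts call `fit` with `steps_per_epoch=len(trainx2)`. Assuming TensorFlow 2.x `tf.keras` with NumPy arrays passed and no `batch_size` (the TensorFlow version in use is not stated), the array data adapter infers the batch size as `ceil(num_samples / steps)`, so each training step sees a single image. A pure-Python sketch of that arithmetic; the helper `inferred_batch_size` and the sample count are hypothetical:

```python
import math


def inferred_batch_size(num_samples, steps=None):
    """Hypothetical helper mirroring how tf.keras 2.x picks a batch size
    for NumPy inputs when batch_size is not given."""
    if steps:
        # Spread the samples evenly over the requested number of steps
        return int(math.ceil(num_samples / steps))
    return 32  # tf.keras default when neither batch_size nor steps is set


num_samples = 1000  # hypothetical stand-in for len(trainx2)
print(inferred_batch_size(num_samples, steps=num_samples))  # 1
print(inferred_batch_size(num_samples))                     # 32
```

Under the same assumption, omitting `steps_per_epoch` falls back to the default batch size of 32.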
**ResNet50: low accuracy of 0.1 :(**
```
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Model

resnet50_imagenet_model = ResNet50(weights='imagenet', include_top=False,
                                   input_shape=(150, 150, 3), pooling='avg')
# Add the same classification head on top of the ResNet50 output
x = resnet50_imagenet_model.output
x = Flatten()(x)
x = Dense(128, activation='relu')(x)
x = Dense(12, activation='softmax')(x)
model = Model(inputs=resnet50_imagenet_model.input, outputs=x)
count = 0
# Set the first 176 ResNet50 layers to trainable=False
for l in resnet50_imagenet_model.layers:
    count = count + 1
    if count <= 176:
        l.trainable = False
# Choose the training algorithm
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# Train the model
model.fit(trainx2, trainy2, epochs=7, steps_per_epoch=len(trainx2))
# Evaluate the model
test_loss3, test_acc3 = model.evaluate(testx2, testy2, verbose=2)
print('\nTest accuracy:', test_acc3)
print('\nTest loss:', test_loss3)
```