在目标标签比预训练分类器少的数据上转移学习

时间:2020-08-18 13:17:20

标签: python-3.x tensorflow deep-learning transfer-learning

假设有一个预先训练的模型(base_model),已经使用大型数据集对其进行了训练,以预测7种人类情绪,例如

'Anger', 'Disgust', 'Fear', 'Happiness', 'Sadness', 'Surprise','Neutral'

现在,为了构建学习转移的模型,我将删除“ base_model”的最后一层,冻结其权重并使其不可训练,然后添加我自己的可调整的微调层。 / p>

我想知道如何在一个较小的数据集上训练新编译的模型“ model_finetuned”,该数据集仅包含7种情绪中的3种,即

'Anger', 'Sadness', 'Surprise'

任何以Python代码形式的帮助或建议将不胜感激。预先感谢!

2 个答案:

答案 0 :(得分:1)

正如您正确解释的那样,您可以冻结预设的模型权重并进行微调,在模型的末尾添加完全连接的层。

有两种利用预训练网络的方法:特征提取和微调。

  • 特征提取:包括使用先前网络学到的表示从新样本中提取有趣的特征。然后,通过新分类器运行这些功能,该分类器从头开始进行培训。 (冷为最后一个完全连接的层)

  • 微调:在于解冻用于特征提取的冻结模型库的顶层,并共同训练模型的新添加部分。

  • >

具有预训练的vgg16的示例:

#Load pretrained vgg16 network
from torchvision.models import vgg16

num_classes = 3
pretrained_model = vgg16(pretrained=True)
pretrained_model.eval()
pretrained_model.to(device)

#Extracting the first part of the model
feature_extractor = pretrained_model.features

#Define feature classifier
feature_classifier = nn.Sequential(
nn.Linear(4*4*512,256),
nn.ReLU(),
nn.Linear(256, num_classes))

#
model = nn.Sequential(
feature_extractor,
nn.Flatten(),
feature_classifier)

如您所见,您必须在最后一个完全连接的层中指定模型的输出。您的情况是(num_classes = 3)。

答案 1 :(得分:0)

这是我几天前使用Tensorflow Keras处理的代码的示例

import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.models import Sequential
from tensorflow.python.keras.layers import Dense

num_classes= 3

# Include the path of the weights for the pretrained model
resnet_weights_path='imagenet'

# Create your model
model= Sequential()

# Include the pre-trained model. In this case, ResNet50
model.add(ResNet50(include_top=False,pooling='avg',weights=resnet_weights_path ))

# Add as many extra layers as you need, according to you problem
# You can also try it directly

# Add the final layer that makes predictions. Suit yourself with the activation function 
model.add(Dense(num_classes,activation='softmax'))

# Don't train the pre-trained model
model.layers[0].trainable=False

# Compile your model according to your needs
model.compile(optimizer='sgd',loss='categorical_crossentropy',metrics=['accuracy'])

现在您可以训练模型了。