以ImageNet上的图像分类为例,如何使用新数据点更新预训练模型。 我已经加载了预训练模型。我有一个新的数据点,该数据点与之前训练模型的原始数据的分布完全不同。因此,我想借助新的数据点来更新/微调模型。如何去做呢?有人可以帮我吗?我正在使用pytorch 0.4.0来实现,并在GPU Tesla K40C上运行。
答案 0 :(得分:1)
如果您不想更改分类器的输出(即类别数),则可以简单地继续使用新的示例图像训练模型,并假设它们已重塑为与预训练模型相同的形状接受。
另一方面,如果要更改预训练模型中的类数,则可以用新的替换最后一个完全连接的层,并仅在新样本上训练该特定层。这是来自PyTorch's autograd mechanics notes的这种情况的示例代码:
model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
param.requires_grad = False
# Replace the last fully-connected layer
# Parameters of newly constructed modules have requires_grad=True by default
model.fc = nn.Linear(512, 100)
# Optimize only the classifier
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)