针对新数据点更新经过预先训练的深度学习模型

时间:2018-12-05 03:31:32

标签: python deep-learning classification conv-neural-network pytorch

以ImageNet上的图像分类为例,如何使用新数据点更新预训练模型。 我已经加载了预训练模型。我有一个新的数据点,该数据点与之前训练模型的原始数据的分布完全不同。因此,我想借助新的数据点来更新/微调模型。如何去做呢?有人可以帮我吗?我正在使用pytorch 0.4.0来实现,并在GPU Tesla K40C上运行。

1 个答案:

答案 0 :(得分:1)

如果您不想更改分类器的输出(即类别数),则可以简单地继续使用新的示例图像训练模型,并假设它们已重塑为与预训练模型相同的形状接受。

另一方面,如果要更改预训练模型中的类数,则可以用新的替换最后一个完全连接的层,并仅在新样本上训练该特定层。这是来自PyTorch's autograd mechanics notes的这种情况的示例代码:

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
# Replace the last fully-connected layer
# Parameters of newly constructed modules have requires_grad=True by default
model.fc = nn.Linear(512, 100)

# Optimize only the classifier
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)