从头开始训练 pytorch 模型需要更改哪些参数?

时间:2020-12-19 00:34:24

标签: python deep-learning pytorch

我按照本教程训练了一个用于实例分割的 pytorch 模型: https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html

我不想在与 COCO 完全无关的完全不同的数据和类上训练模型。我需要做哪些改变来重新训练模型。根据我的阅读,我猜除了正确数量的课程外,我只需要训练这条线:

model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False)

但我注意到还有另一个参数:pretrained_backbone=True, trainable_backbone_layers=None 是否也应该更改?

2 个答案:

答案 0 :(得分:1)

函数签名是

<块引用>

Please enter a number!

设置 library(dplyr) library(stringr) df = data.frame( first_bp = c("120/80","90/60"), id = c("0001234","0001235"), amount = c(18.50, -18.50), stringsAsFactors = F) df %>% mutate(s0 = str_split(first_bp,"/")) %>% rowwise() %>% mutate(systole = as.numeric(s0[1]), diastole = as.numeric(s0[2])) %>% select(first_bp, id, amount, systole, diastole) 会告诉 PyTorch 不要下载在 COCO train2017 上预训练的模型。你想要它,因为你对培训感兴趣。

通常,如果您想在不同的数据集上进行训练,这就足够了。

当您设置 torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=3, **kwargs) 时,PyTorch 将在 ImageNet 上下载预训练的 ResNet50。默认情况下,它会冻结前两个名为 pretrained=Falsepretrained=False 的块。这就是 Faster R-CNN 论文中的方法,它冻结了预训练主干的初始层。

(只需打印模型以检查其结构)。

conv1

现在,如果你甚至不想让前两层冻结,你可以设置layer1(当你设置layers_to_train = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1'][:trainable_layers] 时自动完成),这将从头开始训练整个resnet主干.

检查PR#2160

答案 1 :(得分:0)

来自maskrcnn_resnet50_fpn document

  • pretrained (bool) – 如果为 True,则返回在 COCO train2017 上预训练的模型
  • pretrained_backbone (bool) – 如果为 True,则返回在 Imagenet 上预训练了主干的模型
  • trainable_backbone_layers (int) – 从最后一个块开始的可训练(未冻结)resnet 层的数量。有效值介于 0 到 5 之间,其中 5 表示所有主干层都是可训练的。

所以从头开始训练使用:

model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False, trainable_backbone_layers=5, num_classes=your_num_classes)

或:

model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False, num_classes=your_num_classes)

因为在 maskrcnn_resnet50_fpn 的源代码中:

if not (pretrained or pretrained_backbone):
    trainable_backbone_layers = 5
相关问题