如何在pytorch的自定义数据集上训练deeplabv3?

时间:2020-09-14 21:25:31

标签: python deep-learning pytorch deeplab

进口火炬 型号= torch.hub.load('pytorch / vision:v0.6.0','deeplabv3_resnet101',pretrained = True) model.eval()

Pytorch在Pascal数据集上提供了预训练的deeplabv3,我想在城市景观上对其进行训练。有效的方法是什么?

1 个答案:

答案 0 :(得分:1)

  • 编写custom Dataloader类,该类应继承Dataset类并实现至少两个方法 __len____getitem__
  • 用自定义数量的输出通道修改预训练的DeeplabV3头。
from torchvision.models.segmentation.deeplabv3 import DeepLabHead
from torchvision.models.segmentation import deeplabv3_resnet101

def custom_DeepLabv3(out_channel):
  model = deeplabv3_resnet101(pretrained=True, progress=True)
  model.classifier = DeepLabHead(2048, out_channel)

  #Set the model in training mode
  model.train()
  return model
  • 训练并评估模型。