进口火炬 型号= torch.hub.load('pytorch / vision:v0.6.0','deeplabv3_resnet101',pretrained = True) model.eval()
Pytorch在Pascal数据集上提供了预训练的deeplabv3,我想在城市景观上对其进行训练。有效的方法是什么?
答案 0 :(得分:1)
__len__
和__getitem__
。from torchvision.models.segmentation.deeplabv3 import DeepLabHead
from torchvision.models.segmentation import deeplabv3_resnet101
def custom_DeepLabv3(out_channel):
model = deeplabv3_resnet101(pretrained=True, progress=True)
model.classifier = DeepLabHead(2048, out_channel)
#Set the model in training mode
model.train()
return model