如何在Pytorch中覆盖预训练的模型类?

时间:2020-02-03 20:49:31

标签: deep-learning pytorch resnet

我试图使用经过预训练的ResNet(2 + 1)D [1],但是由于它的第一层使用3个通道,而我只使用一个通道,所以我想必须重写该类。请看看我的尝试,我收到一个错误:

TypeError: _video_resnet() got multiple values for keyword argument 'stem'

[1] https://pytorch.org/docs/stable/_modules/torchvision/models/video/resnet.html#r2plus1d_18

代码:

class R2Plus1dStem4IMAGES(nn.Sequential):
    """R(2+1)D stem is different than the default one as it uses separated 3D convolution
    """
    def __init__(self):
        super(R2Plus1dStem4IMAGES, self).__init__(
            nn.Conv3d(1, 45, kernel_size=(1, 7, 7),
                      stride=(1, 2, 2), padding=(0, 3, 3),
                      bias=False),
            nn.BatchNorm3d(45),
            nn.ReLU(inplace=True),
            nn.Conv3d(45, 64, kernel_size=(3, 1, 1),
                      stride=(1, 1, 1), padding=(1, 0, 0),
                      bias=False),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True))



model = torchvision.models.video.r2plus1d_18(pretrained=True, stem=R2Plus1dStem4IMAGES)

model.fc = nn.Linear(model.fc.in_features, 3)

0 个答案:

没有答案