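I am trying to use the pretrained ResNet(2+1)D [1], but its first layer expects 3 input channels and my data has only one, so I assume I have to override the stem class. Here is my attempt; it fails with this error: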
TypeError: _video_resnet() got multiple values for keyword argument 'stem'
[1] https://pytorch.org/docs/stable/_modules/torchvision/models/video/resnet.html#r2plus1d_18
Code:
import torch.nn as nn
import torchvision

class R2Plus1dStem4IMAGES(nn.Sequential):
    """R(2+1)D stem for single-channel input.

    Differs from the default stem in that it uses separated 3D convolutions.
    """
    def __init__(self):
        super(R2Plus1dStem4IMAGES, self).__init__(
            # Spatial convolution: 1 input channel instead of the usual 3
            nn.Conv3d(1, 45, kernel_size=(1, 7, 7),
                      stride=(1, 2, 2), padding=(0, 3, 3),
                      bias=False),
            nn.BatchNorm3d(45),
            nn.ReLU(inplace=True),
            # Temporal convolution
            nn.Conv3d(45, 64, kernel_size=(3, 1, 1),
                      stride=(1, 1, 1), padding=(1, 0, 0),
                      bias=False),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True))

# This is the call that raises the TypeError quoted above
model = torchvision.models.video.r2plus1d_18(pretrained=True, stem=R2Plus1dStem4IMAGES)
# Replace the classifier head with a 3-class output
model.fc = nn.Linear(model.fc.in_features, 3)