How to modify a pretrained Torchvision model in PyTorch to return two outputs for multi-label image classification

Date: 2019-07-10 18:37:27

Tags: python-3.x conv-neural-network pytorch multilabel-classification torchvision

Input: a set of ten "vowels", a set of ten "consonants", and an image dataset in which one vowel and one consonant are written in each image.

Task: identify the vowel and the consonant in a given image.

Approach: first apply CNN hidden layers to the image, then apply two parallel fully connected/dense layers, one classifying the vowel in the image and the other classifying the consonant.

Problem: I am using a pretrained model such as VGG or GoogLeNet. How do I modify such a pretrained model so that it applies two parallel dense layers and returns two outputs?

I have tried two different models, but my question is whether a pretrained model can be modified for this task.

Right now my model has only one `fc` layer. I modified the number of neurons in the final `fc` layer like this:

final_in_features = googlenet.fc.in_features

googlenet.fc = nn.Linear(final_in_features, 10)

But I need one more `fc` layer, so that both `fc` layers are connected in parallel to the hidden layers.

Right now the model returns only one output.

outputs1 = googlenet(inputs)

The task is to return two outputs from the two `fc` layers, so it should look like this:

outputs1, outputs2 = googlenet(inputs)

2 answers:

Answer 0 (score: 1)

Here is the source of PyTorch's `Linear` layer:

class Linear(Module):
    r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of
          additional dimensions and :math:`H_{in} = \text{in\_features}`
        - Output: :math:`(N, *, H_{out})` where all but the last dimension
          are the same shape as the input and :math:`H_{out} = \text{out\_features}`.

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in\_features})`. The values are
            initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in\_features}}`
        bias:   the learnable bias of the module of shape :math:`(\text{out\_features})`.
                If :attr:`bias` is ``True``, the values are initialized from
                :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                :math:`k = \frac{1}{\text{in\_features}}`

    Examples::

        >>> m = nn.Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """
    __constants__ = ['bias']

    def __init__(self, in_features, out_features, bias=True):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    @weak_script_method
    def forward(self, input):
        return F.linear(input, self.weight, self.bias)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )

You can create a `DoubleLinear` class in the same spirit:

class DoubleLinear(Module):
    def __init__(self, Linear1, Linear2):
        super(DoubleLinear, self).__init__()  # required so the two sub-layers are registered
        self.Linear1 = Linear1
        self.Linear2 = Linear2

    def forward(self, input):
        # apply both heads to the same features and return a tuple
        return self.Linear1(input), self.Linear2(input)

Then create the two linear layers:

Linear_vow = nn.Linear(final_in_features, 10)
Linear_con = nn.Linear(final_in_features, 10)
final_layer = DoubleLinear(Linear_vow, Linear_con)

Now `outputs1, outputs2 = final_layer(inputs)` will work as expected, and assigning `googlenet.fc = final_layer` plugs the two parallel heads into the pretrained backbone.
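Putting the pieces together, here is a minimal runnable sketch. It uses a random feature tensor as a stand-in for the pooled GoogLeNet features (1024 is the `in_features` of torchvision GoogLeNet's `fc`, but any backbone output size works the same way), so it runs without downloading pretrained weights:

```python
import torch
import torch.nn as nn

class DoubleLinear(nn.Module):
    def __init__(self, linear1, linear2):
        super().__init__()  # registers the two heads as submodules
        self.linear1 = linear1
        self.linear2 = linear2

    def forward(self, x):
        # Both heads see the same features; a tuple of logits comes back.
        return self.linear1(x), self.linear2(x)

final_in_features = 1024  # googlenet.fc.in_features in torchvision's GoogLeNet
linear_vow = nn.Linear(final_in_features, 10)  # vowel head
linear_con = nn.Linear(final_in_features, 10)  # consonant head
final_layer = DoubleLinear(linear_vow, linear_con)

features = torch.randn(4, final_in_features)  # stand-in for backbone output
out_vow, out_con = final_layer(features)
print(out_vow.shape, out_con.shape)  # torch.Size([4, 10]) torch.Size([4, 10])
```

With a real model you would then set `googlenet.fc = final_layer`, after which `outputs1, outputs2 = googlenet(inputs)` unpacks directly.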

Answer 1 (score: 0)

class DoubleLinear(torch.nn.Module):
    def __init__(self, Linear1, Linear2):
        super(DoubleLinear, self).__init__()
        self.Linear1 = Linear1
        self.Linear2 = Linear2

    def forward(self, input):
        return self.Linear1(input), self.Linear2(input)


in_features = model._fc.in_features  # `_fc` is model-specific (e.g. EfficientNet); torchvision's GoogLeNet/ResNet use `model.fc`

Linear_first = nn.Linear(in_features, 10)
Linear_second = nn.Linear(in_features, 5)

model._fc = DoubleLinear(Linear_first, Linear_second)