我可以同时训练两个网络,一个网络包含另一个网络。

时间:2019-10-24 06:35:13

标签: networking pytorch

我在模型中定义了两个网络。

def Net_One():
    conv2d
    conv2d
    ...
def Net_Two():
    Net_One(input)
    conv2d
    fc

所以我的问题是:当我使用反向传播优化训练Net_Two时,pytorch是否会自动训练Net_One?为什么?

1 个答案:

答案 0 :(得分:0)

首先,模型类名称的约定应为CapWordsNetOneNetTwo,但这并没有什么害处,只是约定。

关于您的问题,这取决于NetTwo()的处理。
如果NetTwo的最终损失与NetOne()没有关系,则反向传播将不会流经NetOne(),因此不会更新NetOne()的参数。 否则,反向传播将计算NetOne()的梯度并更新其权重。

对于代码示例:

# NetTwo's loss has nothing to do with NetOne:
def NetOne():
    def __init__(self):
        super(NetOne, self).__init__()
        ...

    def forward(inputs):
        ...


def NetTwo():
    def __init__(self):
        super(NetTwo, self).__init__()
        ...

    def forward(inputs):
        # in this processing, temp is never used by NetTwo's layers..
        ...
        temp = NetOne(inputs)
        inputs = conv2d(inputs)
        ...

在高位代码中,从未使用temp,因此在NetTwo更新时不会更新NetOne。

但是如果NetTwo.forward()使用temp,它将被更新,如下所示:

def NetTwo():
    def __init__(self):
        super(NetTwo, self).__init__()
        ...

    def forward(inputs):
        # in this processing, temp is used by NetTwo's layers.
        ...
        temp = NetOne(inputs)
        inputs = conv2d(temp)
        ...

此答案回答了您的问题吗?