如何使用另一个类的返回值?

时间:2017-12-03 13:23:55

标签: python pytorch

我的代码如下所示:

class RegularConv(nn.module):
    def __init__(self):
        super().__init__.()
        self.pool1 = nn.MaxPool2d(2, stride=2, return_indices=True)

    def forward (self, inputs):
        inputs, indices = self.pool1(inputs)
        return inputs, indices

class UpsamplingConv(nn.module):
    def __init__(self, indices):
        super().__init__.()
        self.unpool = nn.MaxUnpool2d(2, stride=2, padding=0)

    def forward (self, output):
        output = self.unpool(output, indices, output_size=output_size)
        return output
class Net(nn.module):
    def __init__(self):
        super().__init__.()
        self.layer1 = RegularConv()
        self.layer2 = UpsamplingConv()

    def forward(self, input):
        input, indices = self.layer1(input)
        input = self.layer2(input, indices)

RegularConv类中的forward方法返回输入和索引,UpsamplingConv类中的forward方法必须使用这些索引。

我在另一个类Net中使用这两个类。我想使用一个类的返回值并在另一个类中使用。我想将RegularConv中返回的索引作为UpsamplingConv

中的参数传递

1 个答案:

答案 0 :(得分:0)

您可以尝试将其作为下面的列表返回,

def forward (self, inputs):
        inputs, indices = self.pool1(sub_inputs)
        return [inputs, indices]

并使用列表的索引来调用该参数。

test = UpsamplingConv(RegularConv()[1])