我的代码如下所示:
class RegularConv(nn.module):
def __init__(self):
super().__init__.()
self.pool1 = nn.MaxPool2d(2, stride=2, return_indices=True)
def forward (self, inputs):
inputs, indices = self.pool1(inputs)
return inputs, indices
class UpsamplingConv(nn.module):
def __init__(self, indices):
super().__init__.()
self.unpool = nn.MaxUnpool2d(2, stride=2, padding=0)
def forward (self, output):
output = self.unpool(output, indices, output_size=output_size)
return output
class Net(nn.module):
def __init__(self):
super().__init__.()
self.layer1 = RegularConv()
self.layer2 = UpsamplingConv()
def forward(self, input):
input, indices = self.layer1(input)
input = self.layer2(input, indices)
RegularConv
类中的forward方法返回输入和索引,UpsamplingConv
类中的forward方法必须使用这些索引。
我在另一个类Net
中使用这两个类。我想使用一个类的返回值并在另一个类中使用。我想将RegularConv
中返回的索引作为UpsamplingConv
答案 0 :(得分:0)
您可以尝试将其作为下面的列表返回,
def forward (self, inputs):
inputs, indices = self.pool1(sub_inputs)
return [inputs, indices]
并使用列表的索引来调用该参数。
test = UpsamplingConv(RegularConv()[1])