在Pytorch中传递模块列表时语法错误无效

时间:2018-01-19 22:43:11

标签: python python-2.7 pytorch

我的深模型中有两个块,定义如下:

def make_conv_bn_relu(in_channels, out_channels, kernel_size=3, stride=1, padding=1):
    return [
        nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]

def make_conv_relu(in_channels, out_channels, kernel_size=3, stride=1, padding=1):
    return [
        nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=True),
        nn.ReLU(inplace=True),
    ]

现在,我想在nn.Sequential传递它。

self.down1 = nn.Sequential(*make_conv_bn_relu(in_channels, 16, kernel_size=3, stride=1, padding=1 ), *make_conv_bn_relu(16, 32, kernel_size=3, stride=2, padding=1 ),)

但是我收到以下错误:

Traceback (most recent call last):
  File "train_unet.py", line 17, in <module>
    from net.model.unet1 import UNet256_3x3 as Net
  File "/home/avijit.d/Kaggle/Pytorch/source/dummy-01/net/model/unet1.py", line 40
    self.down1 = nn.Sequential(*make_conv_bn_relu(in_channels, 16, kernel_size=3, stride=1, padding=1 ), *make_conv_bn_relu(16, 32, kernel_size=3, stride=2, padding=1 ),)
                                                                                                         ^
SyntaxError: invalid syntax

如何摆脱这个?我使用的是Python 2.7

1 个答案:

答案 0 :(得分:1)

你不能在python2中使用多个解压缩包。但如果你真的想使用它,那么只需连接列表:

nn.Squential(*(make_foo() + make_bar()))