我这里有一个很简单的问题。
我刚刚完成了重新配置网络的操作,将nn.Upsample
替换为下面代码中所示的upConv
顺序容器。我已经通过运行summary(UNetPP, (3, 128, 128))
验证了所有内容,并且运行没有问题。
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
class blockUNetPP(nn.Module):
def __init__(self, in_channels, middle_channels, out_channels):
super().__init__()
self.relu = nn.LeakyReLU(0.2, inplace=True)
self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
self.bn1 = nn.BatchNorm2d(middle_channels)
self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
self.bn2 = nn.BatchNorm2d(out_channels)
def forward(self, x):
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
return out
class upConv(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.upc = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
nn.Conv2d(in_ch, out_ch*2, 3, stride=1, padding=1),
nn.BatchNorm2d(out_ch*2),
nn.ReLU(inplace=True)
)
def forward(self, x):
out = self.upc(x)
return out
我的问题是,当我尝试开始训练模型时,出现以下问题:
Traceback (most recent call last):
File "runTrain.py", line 90, in <module>
netG.apply(weights_init)
File "C:\Users\Anaconda3\envs\CFD\lib\site-packages\torch\nn\modules\module.py", line 289, in apply
module.apply(fn)
File "C:\Users\Anaconda3\envs\CFD\lib\site-packages\torch\nn\modules\module.py", line 290, in apply
fn(self)
File "D:\Thesis Models\Deep_learning_models\UNet\train\NetC.py", line 8, in weights_init
m.weight.data.normal_(0.0, 0.02)
File "C:\Users\Anaconda3\envs\CFD\lib\site-packages\torch\nn\modules\module.py", line 594, in __getattr__
type(self).__name__, name))
AttributeError: 'upConv' object has no attribute 'weight'
我查看了solutions,这建议循环遍历容器模块,但是我已经使用weights_init(m)
进行了此操作。有人可以解释我当前设置的问题吗?
答案 0 :(得分:1)
您正在确定如何初始化权重,方法是检查类名是否包含classname.find('Conv')
的 Conv 。您的类的名称为 upConv ,其中包括 Conv ,因此您尝试初始化其属性.weight
,但该属性不存在。
重命名您的类或使条件更严格,例如classname.find('Conv2d')
。最严格的方法是检查它是否为nn.Conv2d
的实例,而不是查看类的名称。
def weights_init(m):
if isinstance(m, nn.Conv2d):
m.weight.data.normal_(0.0, 0.02)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)