PyTorch network not processed correctly after loading

Time: 2019-11-12 18:28:47

Tags: python pytorch

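I am trying to build a network in PyTorch that takes 3x64x64 images. I seem to have trained the network successfully and saved it. The network looks like this: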

import torch
import torch.nn as nn

# conv(), dense(), w_rsz and h_rsz are defined elsewhere in the script (not shown).
class LC_small(nn.Module):
    def __init__(self, c_in, c_out=256):
        super(LC_small, self).__init__()
        self.conv1 = conv(c_in, 64, k=3, stride=1, pad=1)
        self.conv2 = conv(64, 128, k=3, stride=2, pad=1)
        self.conv3 = conv(128, 128, k=3, stride=1, pad=1)
        self.conv4 = conv(128, 128, k=3, stride=2, pad=1)
        self.conv5 = conv(128, 128, k=3, stride=1, pad=1)
        self.conv6 = conv(128, 256, k=3, stride=2, pad=1)
        self.conv7 = conv(256, 256, k=3, stride=1, pad=1)  # output: 256 x (h/8) x (w/8)
        self.flat = dense(int(w_rsz/8) * int(h_rsz/8) * 256, 256)
        self.dense1 = dense(256, 128, False)
        self.dense2 = dense(128, 3, False)

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)
        out = self.conv6(out)
        out = self.conv7(out)
        out = out.view(out.size(0), -1)  # flatten everything except the batch dim
        out = self.flat(out)
        out = self.dense1(out)
        out = self.dense2(out)
        # print(out.shape)
        normal = torch.nn.functional.normalize(out, 2, 1)  # unit L2 norm per sample

        return normal
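The size passed to self.flat, int(w_rsz/8)*int(h_rsz/8)*256, relies on the three stride-2 convolutions halving a 64x64 input three times, down to 8x8. A minimal sketch to check that arithmetic, assuming conv() is a Conv2d + BatchNorm2d + LeakyReLU block; this helper is a hypothetical reconstruction guessed from the module types listed in the save-time warnings, since the real one is not shown:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the question's conv() helper (not shown in the post).
# The save warnings mention Sequential, Conv2d, BatchNorm2d and LeakyReLU,
# so this guesses a Conv2d -> BatchNorm2d -> LeakyReLU block.
def conv(c_in, c_out, k=3, stride=1, pad=1):
    return nn.Sequential(
        nn.Conv2d(c_in, c_out, kernel_size=k, stride=stride, padding=pad),
        nn.BatchNorm2d(c_out),
        nn.LeakyReLU(0.1),  # negative slope is an assumption
    )

# (c_in, c_out, stride) for conv1..conv7, copied from LC_small.__init__
specs = [(3, 64, 1), (64, 128, 2), (128, 128, 1), (128, 128, 2),
         (128, 128, 1), (128, 256, 2), (256, 256, 1)]
features = nn.Sequential(*[conv(ci, co, stride=s) for ci, co, s in specs])

x = torch.zeros(2, 3, 64, 64)  # a batch of two 3x64x64 images: (N, C, H, W)
out = features(x)
flat = out.view(out.size(0), -1)
print(out.shape)   # torch.Size([2, 256, 8, 8])
print(flat.shape)  # torch.Size([2, 16384]), i.e. 8 * 8 * 256 values per image
```

Note that the block only works on 4-dimensional (N, C, H, W) input, which is what the training batches provide.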

Then I saved my model during training:

for epoch in range(10):
    # continue    # (skip training: assume the network is already trained)
    total_loss = 0
    # Log each parameter's name and size
    with open(route_diffuse + '/netparam.txt', 'w') as route_param:
        for param in lcnet.state_dict():
            route_param.write(str(param) + '\t' + str(lcnet.state_dict()[param].size()) + '\n')
    for i, data in enumerate(load_LC, 0):
        input, gtval = data[0].to(dev), data[1].to(dev)
        opt.zero_grad()

        output = lcnet(input)
        loss = crit(output, gtval)
        loss.backward()
        opt.step()
        total_loss += loss.item()
        if i % 10 == 9:
            print(epoch, i, total_loss / 10)
            torch.save(lcnet, route_save)  # saves the whole module object
            total_loss = 0
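During training the network never sees a bare 3-dimensional image, because each item from the loader already carries a leading batch dimension. A minimal sketch of that, assuming load_LC is a torch.utils.data.DataLoader over images stored as C x H x W tensors; the random dataset below is hypothetical:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical data: 20 images stored as C x H x W, each with a 3-vector target.
images = torch.rand(20, 3, 64, 64)
targets = torch.rand(20, 3)
load_LC = DataLoader(TensorDataset(images, targets), batch_size=4, shuffle=True)

print(images[0].shape)               # torch.Size([3, 64, 64]): one sample, no batch dim
data = next(iter(load_LC))
print(data[0].shape, data[1].shape)  # torch.Size([4, 3, 64, 64]) torch.Size([4, 3])
```

The DataLoader stacks samples along a new first dimension, so a single image fed to the network by hand has to get that dimension added explicitly.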

However, when I load the network and try to run it, I get the following error message:

Traceback (most recent call last):
  File "E:/DLPrj/venv/torch_practice.py", line 324, in <module>
    ipl,npl = getseqi_np(sq_t,lcnet)   #  data : 8 x 6 x w x h 
  File "E:/DLPrj/venv/torch_practice.py", line 133, in getseqi_np
    l1 = net_lc(torch.from_numpy(i1r))
  File "E:\DLPrj\venv\lib\site-packages\torch\nn\modules\module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "E:/DLPrj/venv/torch_practice.py", line 216, in forward
    out = self.conv1(input)
  File "E:\DLPrj\venv\lib\site-packages\torch\nn\modules\module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "E:\DLPrj\venv\lib\site-packages\torch\nn\modules\container.py", line 92, in forward
    input = module(input)
  File "E:\DLPrj\venv\lib\site-packages\torch\nn\modules\module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "E:\DLPrj\venv\lib\site-packages\torch\nn\modules\conv.py", line 345, in forward
    return self.conv2d_forward(input, self.weight)
  File "E:\DLPrj\venv\lib\site-packages\torch\nn\modules\conv.py", line 342, in conv2d_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Expected 4-dimensional input for 4-dimensional weight 64 3 3 3, but got 3-dimensional input of size [64, 64, 3] instead

This error freezes PyCharm, and I cannot run the code again until I restart PyCharm.

While training the network, I also get some warning messages:

E:\DLPrj\venv\lib\site-packages\torch\serialization.py:292: UserWarning: Couldn't retrieve source code for container of type LC_small. It won't be checked for correctness upon loading.
  "type " + obj.__name__ + ". It won't be checked "
E:\DLPrj\venv\lib\site-packages\torch\serialization.py:292: UserWarning: Couldn't retrieve source code for container of type Sequential. It won't be checked for correctness upon loading.
  "type " + obj.__name__ + ". It won't be checked "
E:\DLPrj\venv\lib\site-packages\torch\serialization.py:292: UserWarning: Couldn't retrieve source code for container of type Conv2d. It won't be checked for correctness upon loading.
  "type " + obj.__name__ + ". It won't be checked "
E:\DLPrj\venv\lib\site-packages\torch\serialization.py:292: UserWarning: Couldn't retrieve source code for container of type BatchNorm2d. It won't be checked for correctness upon loading.
  "type " + obj.__name__ + ". It won't be checked "
E:\DLPrj\venv\lib\site-packages\torch\serialization.py:292: UserWarning: Couldn't retrieve source code for container of type LeakyReLU. It won't be checked for correctness upon loading.
  "type " + obj.__name__ + ". It won't be checked "
E:\DLPrj\venv\lib\site-packages\torch\serialization.py:292: UserWarning: Couldn't retrieve source code for container of type Linear. It won't be checked for correctness upon loading.
  "type " + obj.__name__ + ". It won't be checked "

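I don't understand why the input size the network expects suddenly changed, or why my network is being saved incorrectly. Please take a look at my problem; thanks a lot.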

1 Answer:

Answer 0 (score: 0)

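So your first error occurs because torch.from_numpy(i1r) has the wrong shape. You need to add a batch dimension and put the channels first instead of last, turning the [64, 64, 3] tensor into shape [1, 3, 64, 64]; then it will be processed correctly. This is because conv1 expects a 4-dimensional (batch, channels, height, width) input, while your tensor has no batch dimension and has the channels in the last dimension instead of the first.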
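A minimal sketch of that conversion, assuming i1r is a 64 x 64 x 3 NumPy array in height x width x channels order (as the [64, 64, 3] in the error message suggests). The random array is a hypothetical stand-in, and the final .float() is only needed if i1r is float64 (NumPy's default), because nn.Conv2d weights are float32 by default and a float64 input would fail with a dtype mismatch:

```python
import numpy as np
import torch
import torch.nn as nn

# Hypothetical stand-in for i1r: one 64x64 RGB image in H x W x C order (float64).
i1r = np.random.rand(64, 64, 3)

x = torch.from_numpy(i1r)  # [64, 64, 3]     (H, W, C)
x = x.permute(2, 0, 1)     # [3, 64, 64]     (C, H, W): channels first
x = x.unsqueeze(0)         # [1, 3, 64, 64]  (N, C, H, W): add a batch dim
x = x.float()              # float64 -> float32 to match the layer weights
print(x.shape, x.dtype)    # torch.Size([1, 3, 64, 64]) torch.float32

# A layer with the same weight shape as conv1 in the error message: 64 x 3 x 3 x 3
conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
print(conv1.weight.shape)  # torch.Size([64, 3, 3, 3])
y = conv1(x)
print(y.shape)             # torch.Size([1, 64, 64, 64])
```

For inference it is also worth calling net_lc.eval() first, so the BatchNorm layers use their running statistics instead of the statistics of a single image.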

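As for the second set of messages (the UserWarnings printed while saving), they are probably caused by the way you defined conv and dense, which trips things up when the model is saved.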
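One common way to avoid pickling the module classes at all, and with them these source-code warnings, is to save only the state_dict (as the PyTorch "Saving and Loading Models" tutorial recommends) and rebuild the network from code before loading the weights. A minimal sketch, using a hypothetical TinyNet in place of LC_small and a temporary file in place of route_save:

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in for LC_small; any nn.Module is saved/loaded the same way.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.fc = nn.Linear(8, 3)

    def forward(self, x):
        out = self.conv1(x).mean(dim=(2, 3))  # global average pool -> [N, 8]
        return nn.functional.normalize(self.fc(out), 2, 1)

net = TinyNet()
path = os.path.join(tempfile.mkdtemp(), "lcnet.pt")  # hypothetical path

# Save only the parameters and buffers, not the pickled module object.
torch.save(net.state_dict(), path)

# Later: rebuild the architecture from code, then load the weights into it.
restored = TinyNet()
restored.load_state_dict(torch.load(path))
restored.eval()  # inference mode for BatchNorm/Dropout layers, if any

x = torch.rand(1, 3, 64, 64)
with torch.no_grad():
    same = torch.allclose(net(x), restored(x))
print(same)  # True
```

The trade-off is that the class definition must be importable wherever the weights are loaded, which is usually the case when the training and inference code live in the same project.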