I'm trying to build a network in PyTorch for 3x64x64 images. It looks like I trained the network successfully and saved it. The network looks like this:
import torch
import torch.nn as nn
# conv(), dense(), w_rsz and h_rsz are defined elsewhere in the script (not shown here).
class LC_small(nn.Module):
    def __init__(self,c_in,c_out = 256):
        super(LC_small,self).__init__()
        self.conv1 = conv(c_in,64,k=3,stride=1,pad=1)
        self.conv2 = conv(64, 128, k=3, stride=2, pad=1)
        self.conv3 = conv(128, 128, k=3, stride=1, pad=1)
        self.conv4 = conv(128, 128, k=3, stride=2, pad=1)
        self.conv5 = conv(128, 128, k=3, stride=1, pad=1)
        self.conv6 = conv(128, 256, k=3, stride=2, pad=1)
        self.conv7 = conv(256, 256, k=3, stride=1, pad=1)  # int(h/8 x w/8 x 256)
        self.flat = dense(int(w_rsz/8)*int(h_rsz/8)*256,256)
        self.dense1 = dense(256,128,False)
        self.dense2 = dense(128,3,False)
    def forward(self, input):
        # input is expected as batch x channels x height x width
        out = self.conv1(input)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)
        out = self.conv6(out)
        out = self.conv7(out)
        out = out.view(out.size(0),-1)  # flatten to batch x (256 * h/8 * w/8)
        out = self.flat(out)
        out = self.dense1(out)
        out = self.dense2(out)
        # print(out.shape)
        normal = torch.nn.functional.normalize(out, 2, 1)  # unit L2 norm per sample
        return normal
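The model above depends on conv(), dense(), w_rsz and h_rsz, which the question does not show. The sketch below fills them in with hypothetical definitions so the shapes can be checked end to end: conv() as Conv2d + BatchNorm2d + LeakyReLU and dense() as Linear with an optional LeakyReLU (a guess based only on the layer types named in the save-time warnings), and w_rsz = h_rsz = 64 from the stated 3x64x64 image size. It shows that the network runs on a 4-dimensional batch laid out as N x C x H x W.

```python
import torch
import torch.nn as nn

w_rsz, h_rsz = 64, 64  # assumed from the 3x64x64 images in the question


def conv(c_in, c_out, k=3, stride=1, pad=1):
    # Hypothetical helper: Conv2d -> BatchNorm2d -> LeakyReLU.
    return nn.Sequential(
        nn.Conv2d(c_in, c_out, kernel_size=k, stride=stride, padding=pad),
        nn.BatchNorm2d(c_out),
        nn.LeakyReLU(0.1),
    )


def dense(c_in, c_out, lrelu=True):
    # Hypothetical helper: Linear, optionally followed by LeakyReLU.
    layers = [nn.Linear(c_in, c_out)]
    if lrelu:
        layers.append(nn.LeakyReLU(0.1))
    return nn.Sequential(*layers)


class LC_small(nn.Module):
    def __init__(self, c_in, c_out=256):
        super().__init__()
        self.conv1 = conv(c_in, 64, k=3, stride=1, pad=1)
        self.conv2 = conv(64, 128, k=3, stride=2, pad=1)   # 64 -> 32
        self.conv3 = conv(128, 128, k=3, stride=1, pad=1)
        self.conv4 = conv(128, 128, k=3, stride=2, pad=1)  # 32 -> 16
        self.conv5 = conv(128, 128, k=3, stride=1, pad=1)
        self.conv6 = conv(128, 256, k=3, stride=2, pad=1)  # 16 -> 8
        self.conv7 = conv(256, 256, k=3, stride=1, pad=1)
        self.flat = dense(int(w_rsz / 8) * int(h_rsz / 8) * 256, 256)
        self.dense1 = dense(256, 128, False)
        self.dense2 = dense(128, 3, False)

    def forward(self, x):
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4,
                      self.conv5, self.conv6, self.conv7):
            x = layer(x)
        x = x.view(x.size(0), -1)  # flatten to N x (256 * 8 * 8)
        x = self.dense2(self.dense1(self.flat(x)))
        return torch.nn.functional.normalize(x, 2, 1)  # unit L2 norm per row


net = LC_small(3).eval()
batch = torch.randn(2, 3, h_rsz, w_rsz)  # N x C x H x W
with torch.no_grad():
    y = net(batch)
print(y.shape)  # torch.Size([2, 3])
```

The stride-2 layers halve the spatial size three times (64 to 8), which is why the flattened size is 8 * 8 * 256 and why the input must arrive with channels before height and width.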
Then I saved my model during training:
for epoch in range(10):
    # continue  # assume training is already done
    total_loss = 0
    # write each parameter name and size; "with" closes the file afterwards
    with open(route_diffuse+'/netparam.txt','w') as route_param:
        for param in lcnet.state_dict():
            route_param.write(str(param)+'\t'+str(lcnet.state_dict()[param].size())+'\n')
    for i,data in enumerate(load_LC,0):
        input, gtval = data[0].to(dev),data[1].to(dev)
        opt.zero_grad()
        output = lcnet(input)
        loss = crit(output,gtval)
        loss.backward()
        opt.step()
        total_loss +=loss.item()
        if i%10 == 9:
            print(epoch,i,total_loss/10)
            torch.save(lcnet,route_save)  # saves the whole module object
            total_loss = 0
However, when I try to load the network, I get the following error:
Traceback (most recent call last):
  File "E:/DLPrj/venv/torch_practice.py", line 324, in <module>
    ipl,npl = getseqi_np(sq_t,lcnet) # data : 8 x 6 x w x h
  File "E:/DLPrj/venv/torch_practice.py", line 133, in getseqi_np
    l1 = net_lc(torch.from_numpy(i1r))
  File "E:\DLPrj\venv\lib\site-packages\torch\nn\modules\module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "E:/DLPrj/venv/torch_practice.py", line 216, in forward
    out = self.conv1(input)
  File "E:\DLPrj\venv\lib\site-packages\torch\nn\modules\module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "E:\DLPrj\venv\lib\site-packages\torch\nn\modules\container.py", line 92, in forward
    input = module(input)
  File "E:\DLPrj\venv\lib\site-packages\torch\nn\modules\module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "E:\DLPrj\venv\lib\site-packages\torch\nn\modules\conv.py", line 345, in forward
    return self.conv2d_forward(input, self.weight)
  File "E:\DLPrj\venv\lib\site-packages\torch\nn\modules\conv.py", line 342, in conv2d_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Expected 4-dimensional input for 4-dimensional weight 64 3 3 3, but got 3-dimensional input of size [64, 64, 3] instead
This error made PyCharm freeze, and I couldn't run the code again until I restarted PyCharm.
While training the network, I also get some warning messages:
E:\DLPrj\venv\lib\site-packages\torch\serialization.py:292: UserWarning: Couldn't retrieve source code for container of type LC_small. It won't be checked for correctness upon loading.
  "type " + obj.__name__ + ". It won't be checked "
E:\DLPrj\venv\lib\site-packages\torch\serialization.py:292: UserWarning: Couldn't retrieve source code for container of type Sequential. It won't be checked for correctness upon loading.
  "type " + obj.__name__ + ". It won't be checked "
E:\DLPrj\venv\lib\site-packages\torch\serialization.py:292: UserWarning: Couldn't retrieve source code for container of type Conv2d. It won't be checked for correctness upon loading.
  "type " + obj.__name__ + ". It won't be checked "
E:\DLPrj\venv\lib\site-packages\torch\serialization.py:292: UserWarning: Couldn't retrieve source code for container of type BatchNorm2d. It won't be checked for correctness upon loading.
  "type " + obj.__name__ + ". It won't be checked "
E:\DLPrj\venv\lib\site-packages\torch\serialization.py:292: UserWarning: Couldn't retrieve source code for container of type LeakyReLU. It won't be checked for correctness upon loading.
  "type " + obj.__name__ + ". It won't be checked "
E:\DLPrj\venv\lib\site-packages\torch\serialization.py:292: UserWarning: Couldn't retrieve source code for container of type Linear. It won't be checked for correctness upon loading.
  "type " + obj.__name__ + ". It won't be checked "
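I don't understand why the input size the network expects suddenly changed, or why it saves my network incorrectly. Could you take a look at my problem? Thanks a lot.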
Answer 0 (score: 0)
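So your first error message is caused by torch.from_numpy(i1r) having the wrong shape: it is [64, 64, 3] (height, width, channels). You need to move the channel axis to the front and add a batch axis, so that the input becomes [1, 3, 64, 64]; then it will be processed correctly. This is because the network expects a batch dimension, and the channels must come first (right after the batch dimension), not last.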
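A minimal sketch of that conversion, assuming i1r is a NumPy array laid out height x width x channels as the error message suggests (random values stand in for the real image). The .float() step is an extra assumption: NumPy arrays default to float64, while Conv2d weights are float32 by default, so a float64 input would fail with a dtype error even after the shape is fixed.

```python
import numpy as np
import torch
import torch.nn as nn

# Stand-in for i1r: an H x W x C image, as in the error message.
# np.random.rand returns float64, NumPy's default float type.
i1r = np.random.rand(64, 64, 3)

x = torch.from_numpy(i1r)  # [64, 64, 3]    (H, W, C), float64
x = x.permute(2, 0, 1)     # [3, 64, 64]    (C, H, W)
x = x.unsqueeze(0)         # [1, 3, 64, 64] (N, C, H, W)
x = x.float()              # float32, matching the default Conv2d weights

# A layer whose weight has the shape quoted in the traceback: 64 x 3 x 3 x 3.
conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
print(conv1.weight.shape)  # torch.Size([64, 3, 3, 3])
out = conv1(x)
print(out.shape)           # torch.Size([1, 64, 64, 64])
```

In getseqi_np, the converted tensor would then be passed to net_lc in place of torch.from_numpy(i1r).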
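As for the second message (the warnings printed while saving), it is probably because conv and dense are defined incorrectly, which messes things up when the model is saved.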
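Whatever their root cause, these particular UserWarnings are raised by torch.save when it pickles a whole module object (as torch.save(lcnet,route_save) does) and fails to retrieve each module class's source code. A commonly recommended alternative is to save only the state_dict and rebuild the model from code before loading the weights, which does not pickle the module classes at all. The sketch below uses a small stand-in model and a temporary file path in place of lcnet and route_save; both are assumptions for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model():
    # Small stand-in for LC_small, built from the same layer types.
    return nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.BatchNorm2d(8),
        nn.LeakyReLU(0.1),
    )


model = make_model()
path = os.path.join(tempfile.mkdtemp(), "lcnet_weights.pt")  # stand-in for route_save

# Save only the parameters and buffers, not the pickled module object.
torch.save(model.state_dict(), path)

# To load: rebuild the architecture from code, then restore the weights.
restored = make_model()
restored.load_state_dict(torch.load(path))
restored.eval()  # switch BatchNorm to inference mode before running it
```

The trade-off is that the class definition must be importable when loading, but that code is then the single source of truth for the architecture.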