I changed the following code:
encoder_h_1 = (Variable(torch.zeros(data.size(0), 256, 8, 8).cuda()),
               Variable(torch.zeros(data.size(0), 256, 8, 8).cuda()))
encoder_h_2 = (Variable(torch.zeros(data.size(0), 512, 4, 4).cuda()),
               Variable(torch.zeros(data.size(0), 512, 4, 4).cuda()))
encoder_h_3 = (Variable(torch.zeros(data.size(0), 512, 2, 2).cuda()),
               Variable(torch.zeros(data.size(0), 512, 2, 2).cuda()))
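The spatial sizes 8, 4 and 2 in this first version look consistent with an encoder that takes 32×32 patches and halves the resolution at every stage: one stride-2 convolution before the recurrent layers, then one stride-2 step per ConvLSTM layer. That architecture is an assumption here; the post does not show the model. Under that assumption, the hidden-state sizes can be computed from the input size instead of hard-coded, which makes a mismatch easier to spot:

```python
def hidden_sizes(input_size, num_stride2_before_rnn, num_rnn_layers):
    """Spatial size (H == W) of each recurrent layer's hidden state.

    Assumes every downsampling step is a stride-2 convolution that exactly
    halves H and W, and that each recurrent layer downsamples once.
    """
    # Resolution after the (assumed) stride-2 convolutions before the RNNs.
    size = input_size // (2 ** num_stride2_before_rnn)
    sizes = []
    for _ in range(num_rnn_layers):
        size //= 2  # assumed: each ConvLSTM layer halves the resolution
        sizes.append(size)
    return sizes
```

With these assumptions, `hidden_sizes(32, 1, 3)` gives `[8, 4, 2]`, matching the first version. Getting `[32, 16, 8, 4, 2]` instead requires `hidden_sizes(128, 1, 5)`, i.e. 128×128 inputs and five recurrent layers; if the data or the model does not change accordingly, the hidden states will not line up with the feature maps.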
to this:
encoder_h_1 = (Variable(torch.zeros(data.size(0), 256, 32, 32).cuda()),
               Variable(torch.zeros(data.size(0), 256, 32, 32).cuda()))
encoder_h_2 = (Variable(torch.zeros(data.size(0), 512, 16, 16).cuda()),
               Variable(torch.zeros(data.size(0), 512, 16, 16).cuda()))
encoder_h_3 = (Variable(torch.zeros(data.size(0), 1024, 8, 8).cuda()),
               Variable(torch.zeros(data.size(0), 1024, 8, 8).cuda()))
encoder_h_4 = (Variable(torch.zeros(data.size(0), 2048, 4, 4).cuda()),
               Variable(torch.zeros(data.size(0), 2048, 4, 4).cuda()))
encoder_h_5 = (Variable(torch.zeros(data.size(0), 2048, 2, 2).cuda()),
               Variable(torch.zeros(data.size(0), 2048, 2, 2).cuda()))
I want to implement the second version. How do I get rid of the error?
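Each `(h, c)` pair has to match exactly what the corresponding recurrent layer produces: the channel count must equal that layer's hidden channels, and the spatial size must equal its output resolution. A minimal sketch, assuming the per-layer channels and sizes are read off the model definition (the helper name `encoder_state_shapes` is hypothetical), builds all the shapes from two lists instead of repeating the code per layer:

```python
def encoder_state_shapes(batch_size, channels, sizes):
    """Return the (batch, C, H, W) shape for each layer's hidden/cell state.

    channels[i] must equal the hidden-channel count of the i-th ConvLSTM
    layer and sizes[i] the spatial size that layer actually outputs
    (assumed square feature maps). Both lists come from the model itself.
    """
    if len(channels) != len(sizes):
        raise ValueError("need one channel count and one spatial size per layer")
    return [(batch_size, c, s, s) for c, s in zip(channels, sizes)]
```

For the second version this would be `encoder_state_shapes(data.size(0), [256, 512, 1024, 2048, 2048], [32, 16, 8, 4, 2])`, and each pair could then be created as `(torch.zeros(shape, device="cuda"), torch.zeros(shape, device="cuda"))`. Wrapping in `Variable` is unnecessary since PyTorch 0.4, where `Variable` and `Tensor` were merged. The encoder must also be changed to have five recurrent layers with these channel counts; new initial states alone will not make the shapes agree.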