Missing key(s) in state_dict

Time: 2019-04-18 11:21:40

Tags: image-processing machine-learning computer-vision pytorch generative-adversarial-network

I am having a problem loading my model in Google Colab.

I have tried changing the names in the state dict, but that did not help. Basically, I am trying to save my model for later use, but since I cannot save and load it properly, this has become very difficult. Please help me solve this problem. You will also find the error I get attached below, after the code section.

Here is the code:

from zipfile import ZipFile
file_name = 'data.zip'
with ZipFile(file_name, 'r') as zip:
  zip.extractall()

from zipfile import ZipFile
file_name = 'results.zip'
with ZipFile(file_name, 'r') as zip:
  zip.extractall()

!pip install tensorflow-gpu

from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable


# Setting the hyperparameters
batchSize = 64  # size of the batch
imageSize = 64  # size of the generated images (64x64)

# Creating the transformations applied to the input images
transform = transforms.Compose([transforms.Resize(imageSize), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),])

# Loading the CIFAR-10 dataset
dataset = dset.CIFAR10(root = './data', download = True, transform = transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size = batchSize, shuffle = True, num_workers = 2)


# Initializing the weights of the conv and batch-norm layers of a network m
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


# Defining the generator
class G(nn.Module):

    def __init__(self):
        super(G, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(100, 512, 4, 1, 0, bias = False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias = False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias = False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias = False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias = False),
            nn.Tanh()
        )

    def forward(self, input):
        output = self.main(input)
        return output


# Creating the generator
netG = G()
netG.load_state_dict(torch.load('generator.pth'))
netG.eval()
#netG.apply(weights_init)



# Defining the discriminator
class D(nn.Module):

    def __init__(self):
        super(D, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1, bias = False),
            nn.LeakyReLU(0.2, inplace = True),
            nn.Conv2d(64, 128, 4, 2, 1, bias = False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace = True),
            nn.Conv2d(128, 256, 4, 2, 1, bias = False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace = True),
            nn.Conv2d(256, 512, 4, 2, 1, bias = False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace = True),
            nn.Conv2d(512, 1, 4, 1, 0, bias = False),
            nn.Sigmoid()
        )

    def forward(self, input):
        output = self.main(input)
        return output.view(-1)


# Creating the discriminator
netD = D()
netD.load_state_dict(torch.load('discriminator.pth'))
netD.eval()
#netD.apply(weights_init)


# Setting up the loss and restoring the optimizers and training state from the checkpoints
criterion = nn.BCELoss()
checkpoint = torch.load('discriminator.pth')
optimizerD = optim.Adam(netD.parameters(), lr = 0.0002, betas = (0.5, 0.999))
optimizerD.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
errD = checkpoint['loss']
checkpoint1 = torch.load('genrator.pth')
optimizerG = optim.Adam(netG.parameters(), lr = 0.0002, betas = (0.5, 0.999))
optimizerG.load_state_dict(checkpoint1['optimizer_state_dict'])
errG = checkpoint1['loss']
# Resuming the training loop from the saved epoch
k = epoch
for j in range(k, 10):

    for i, data in enumerate(dataloader, 0):


        # Updating the weights of the discriminator
        netD.zero_grad()


        real, _ = data
        input = Variable(real)
        target = Variable(torch.ones(input.size()[0]))
        output = netD(input)
        errD_real = criterion(output, target)


        noise = Variable(torch.randn(input.size()[0], 100, 1, 1))
        fake = netG(noise)
        target = Variable(torch.zeros(input.size()[0]))
        output = netD(fake.detach())
        errD_fake = criterion(output, target)


        errD = errD_real + errD_fake
        errD.backward()
        optimizerD.step()



        # Updating the weights of the generator
        netG.zero_grad()
        target = Variable(torch.ones(input.size()[0]))
        output = netD(fake)
        errG = criterion(output, target)
        errG.backward()
        optimizerG.step()



        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f' % (epoch+1, 10, i+1, len(dataloader), errD.data, errG.data))
        if i % 100 == 0:
            vutils.save_image(real, '%s/real_samples.png' % "./results", normalize = True)
            fake = netG(noise)
            vutils.save_image(fake.data, '%s/fake_samples_epoch_%03d.png' % ("./results", epoch+1), normalize = True)

torch.save({
            'epoch': epoch,
            'model_state_dict': netD.state_dict(),
            'optimizer_state_dict': optimizerD.state_dict(),
            'loss': errD
            }, 'discriminator.pth')
torch.save({
            'epoch': epoch,
            'model_state_dict': netG.state_dict(),
            'optimizer_state_dict': optimizerG.state_dict(),
            'loss': errG
            }, 'generator.pth')

Here is the error:

RuntimeError                              Traceback (most recent call last)
<ipython-input-23-3e55546152c7> in <module>()
     26 # Creating the generator
     27 netG = G()
---> 28 netG.load_state_dict(torch.load('generator.pth'))
     29 netG.eval()
     30 #netG.apply(weights_init)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
    767         if len(error_msgs) > 0:
    768             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 769                                self.__class__.__name__, "\n\t".join(error_msgs)))
    770 
    771     def _named_members(self, get_members_fn, prefix='', recurse=True):

RuntimeError: Error(s) in loading state_dict for G:
    Missing key(s) in state_dict: "main.0.weight", "main.1.weight", "main.1.bias", "main.1.running_mean", "main.1.running_var", "main.3.weight", "main.4.weight", "main.4.bias", "main.4.running_mean", "main.4.running_var", "main.6.weight", "main.7.weight", "main.7.bias", "main.7.running_mean", "main.7.running_var", "main.9.weight", "main.10.weight", "main.10.bias", "main.10.running_mean", "main.10.running_var", "main.12.weight". 
    Unexpected key(s) in state_dict: "epoch", "model_state_dict", "optimizer_state_dict", "loss".

1 Answer:

Answer 0: (score: 1)

You need to access the 'model_state_dict' key inside the loaded checkpoint, because you saved a dictionary with several entries rather than the state dict itself.

Try:

netG.load_state_dict(torch.load('generator.pth')['model_state_dict'])

You will probably need to apply the same fix to the discriminator as well.
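
For reference, here is a minimal sketch of what the loading section could look like once this fix is applied to both networks. It assumes the checkpoint files were written exactly as in the torch.save calls shown in the question (dictionaries with the keys 'epoch', 'model_state_dict', 'optimizer_state_dict' and 'loss'); the variable names checkpoint_G and checkpoint_D are just illustrative:

# Load each checkpoint dictionary once
checkpoint_G = torch.load('generator.pth')
checkpoint_D = torch.load('discriminator.pth')

# The network weights live under 'model_state_dict', not at the top level
netG = G()
netG.load_state_dict(checkpoint_G['model_state_dict'])

netD = D()
netD.load_state_dict(checkpoint_D['model_state_dict'])

# The optimizer state and training metadata are stored under the other keys
optimizerG = optim.Adam(netG.parameters(), lr = 0.0002, betas = (0.5, 0.999))
optimizerG.load_state_dict(checkpoint_G['optimizer_state_dict'])

optimizerD = optim.Adam(netD.parameters(), lr = 0.0002, betas = (0.5, 0.999))
optimizerD.load_state_dict(checkpoint_D['optimizer_state_dict'])

epoch = checkpoint_D['epoch']
errG = checkpoint_G['loss']
errD = checkpoint_D['loss']

# Since training continues afterwards, keep the networks in train() mode
# rather than calling eval()
netG.train()
netD.train()

Reading all the keys from the dictionary returned by a single torch.load call per file also avoids loading the same checkpoint twice, as the original code does.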