PyTorch-将ProGAN代理从pth转换为onnx

时间:2020-03-25 08:31:55

标签: python machine-learning pytorch onnx

我使用this PyTorch重新实现来训练ProGAN代理,并将代理另存为chown -R www-data:www-data wordpress 。现在,我需要将代理转换为.pth格式,我正在使用以下scipt:

.onnx

一旦运行它,我将收到错误from torch.autograd import Variable import torch.onnx import torchvision import torch device = torch.device("cuda") dummy_input = torch.randn(1, 3, 64, 64) state_dict = torch.load("GAN_agent.pth", map_location = device) torch.onnx.export(state_dict, dummy_input, "GAN_agent.onnx") (下面的完整提示)。据我了解,问题在于将代理转换为.onnx需要更多信息。我想念什么吗?

AttributeError: 'collections.OrderedDict' object has no attribute 'state_dict'

1 个答案:

答案 0 :(得分:3)

这里有state_dict个文件,它们只是层名称到tensor的权重偏差和类似的映射(有关更全面的介绍,请参见here)。

这意味着您需要一个模型,以便可以将保存的权重和偏差映射到上面,但首先要考虑的是

1。模型准备

克隆the repository,并在其中定义模型,并打开文件/pro_gan_pytorch/pro_gan_pytorch/PRO_GAN.py。我们需要进行一些修改才能使其与onnx一起使用。 onnx出口商仅要求input作为torch.tensor传递(或其中的list / dict),而Generator类需要{{1 }}和int参数。

一个简单的解决方案,它可以对float函数(文件中的forward行,您可以on GitHub进行验证)进行以下修改:

80

此处仅添加了通过def forward(self, x, depth, alpha): """ forward pass of the Generator :param x: input noise :param depth: current depth from where output is required :param alpha: value of alpha for fade-in effect :return: y => output """ # THOSE TWO LINES WERE ADDED # We will pas tensors but unpack them here to `int` and `float` depth = depth.item() alpha = alpha.item() # THOSE TWO LINES WERE ADDED assert depth < self.depth, "Requested output depth cannot be produced" y = self.initial_block(x) if depth > 0: for block in self.layers[: depth - 1]: y = block(y) residual = self.rgb_converters[depth - 1](self.temporaryUpsampler(y)) straight = self.rgb_converters[depth](self.layers[depth - 1](y)) out = (alpha * straight) + ((1 - alpha) * residual) else: out = self.rgb_converters[0](y) return out 开箱的信息。每个非item()类型的输入都应在函数定义中打包为一个,并在函数顶部尽快解压缩。它不会破坏您创建的检查点,因此不必担心,因为它只是Tensor映射。

2。模型导出

将此脚本放置在layer-weight(也位于/pro_gan_pytorch的位置)中:

README.md

请注意一些事项:

  • 我们只能在加载权重之前创建模型,因为它仅是import torch from pro_gan_pytorch import PRO_GAN as pg gen = torch.nn.DataParallel(pg.Generator(depth=9)) gen.load_state_dict(torch.load("GAN_GEN_SHADOW_8.pth")) module = gen.module.to("cpu") # Arguments like depth and alpha may need to be changed dummy_inputs = (torch.randn(1, 512), torch.tensor([5]), torch.tensor([0.1])) torch.onnx.export(module, dummy_inputs, "GAN_GEN8.onnx", verbose=True)
  • 需要
  • state_dict,因为这是模型训练的依据(不确定您的情况,请进行相应调整)。加载后,我们可以通过torch.nn.DataParallel属性获取模块本身。
  • 所有内容都强制转换为module,我认为这里不需要CPU。如果您仍然坚持的话,可以将所有内容投放到GPU
  • 生成器的
  • 虚拟输入不能是图像(我使用的是回购作者on their Google Drive提供的文件),它必须带有GPU元素。

运行它,您的512文件应该在那里。

哦,当您在不同的检查点之后时,您可能希望遵循类似的步骤,尽管不能保证一切都可以正常工作(尽管看起来确实如此)。