pix2pixHD shows an error after changing the dataset

Asked: 2017-12-08 16:44:55

Tags: deep-learning pytorch

I am trying out the pix2pixHD code from the link below: https://github.com/NVIDIA/pix2pixHD

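train.py works with the default images (in datasets/cityscapes). However, after I replace the images in the dataset, it fails with the following error.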

    model [Pix2PixHDModel] was created
    create web directory ./checkpoints/label2city/web...
    Traceback (most recent call last):
      File "/home/shimada/venv/py2.7/projects/Hiwi/pix2pixHD/train.py", line 58, in <module>
        Variable(data['image']), Variable(data['feat']), infer=save_fake)
      File "/home/shimada/venv/py2.7/local/lib/python2.7/site-packages/torch/nn/modules/module.py", line 325, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/shimada/venv/py2.7/local/lib/python2.7/site-packages/torch/nn/parallel/data_parallel.py", line 66, in forward
        return self.module(*inputs[0], **kwargs[0])
      File "/home/shimada/venv/py2.7/local/lib/python2.7/site-packages/torch/nn/modules/module.py", line 325, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/shimada/venv/py2.7/projects/Hiwi/pix2pixHD/models/pix2pixHD_model.py", line 141, in forward
        fake_image = self.netG.forward(input_concat)
      File "/home/shimada/venv/py2.7/projects/Hiwi/pix2pixHD/models/networks.py", line 213, in forward
        return self.model(input)             
      File "/home/shimada/venv/py2.7/local/lib/python2.7/site-packages/torch/nn/modules/module.py", line 325, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/shimada/venv/py2.7/local/lib/python2.7/site-packages/torch/nn/modules/container.py", line 67, in forward
        input = module(input)
      File "/home/shimada/venv/py2.7/local/lib/python2.7/site-packages/torch/nn/modules/module.py", line 325, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/shimada/venv/py2.7/local/lib/python2.7/site-packages/torch/nn/modules/conv.py", line 277, in forward
        self.padding, self.dilation, self.groups)
      File "/home/shimada/venv/py2.7/local/lib/python2.7/site-packages/torch/nn/functional.py", line 90, in conv2d
        return f(input, weight, bias)
    RuntimeError: Given groups=1, weight[64, 36, 7, 7], so expected input[1, 39, 518, 1030] to have 36 channels, but got 39 channels instead
    THCudaCheck FAIL file=/pytorch/torch/lib/THC/generic/THCStorage.c line=184 error=59 : device-side assert triggered
    terminate called after throwing an instance of 'std::runtime_error'
      what():  cuda runtime error (59) : device-side assert triggered at /pytorch/torch/lib/THC/generic/THCStorage.c:184
    bash: line 1: 10965 Aborted                 (core dumped) env "PYCHARM_HOSTED"="1" "PYTHONUNBUFFERED"="1" "PYTHONIOENCODING"="UTF-8" "PYCHARM_MATPLOTLIB_PORT"="42188" "JETBRAINS_REMOTE_RUN"="1" "PYTHONPATH"="/home/shimada/.pycharm_helpers/pycharm_matplotlib_backend:/home/shimada/venv/py2.7/projects/Hiwi/pix2pixHD" /home/shimada/venv/py2.7/bin/python -u /home/shimada/venv/py2.7/projects/Hiwi/pix2pixHD/train.py

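I replaced the images with ones of the same size (width 2048, height 1024) and the same extension (.png), and gave them the same file names. Why doesn't it work?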
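For readers decoding the RuntimeError above: PyTorch lays out a Conv2d weight as [out_channels, in_channels / groups, kernel_h, kernel_w] and expects its input as [batch, channels, height, width]. The following is a minimal pure-Python sketch of that channel check, using the shapes from the traceback. The helper name `check_conv2d_channels` is hypothetical and is not part of PyTorch or pix2pixHD; it only mirrors the arithmetic.

```python
def check_conv2d_channels(weight_shape, input_shape, groups=1):
    """Return (expected, actual) input-channel counts for a 2-D convolution.

    weight_shape: [out_channels, in_channels // groups, kernel_h, kernel_w]
    input_shape:  [batch, channels, height, width]
    """
    expected = weight_shape[1] * groups
    actual = input_shape[1]
    return expected, actual


# Shapes copied from the RuntimeError in the traceback.
expected, actual = check_conv2d_channels([64, 36, 7, 7], [1, 39, 518, 1030])
surplus = actual - expected
print("expected %d channels, got %d (surplus: %d)" % (expected, actual, surplus))
```

The first generator layer was built for 36 input channels, while the tensor assembled from the new dataset has 39. This suggests that the replacement files do not decode to the same number of channels as the default ones, even though their size, extension, and names match.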

1 answer:

Answer 0 (score: 0)

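Your original images / ground-truth data appear to be grayscale. In that case, you have to specify --input_nc 1 --output_nc 1 for grayscale. You will also have to modify the pix2pixHD code so that it loads grayscale images.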
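One way to test whether the replacement files match the originals beyond size, extension, and name is to compare their PNG colour type, which determines how many channels a loader sees. The sketch below reads it from the IHDR chunk using only the standard library. `png_info` and `png_info_from_path` are hypothetical helper names, not part of pix2pixHD, and the demo header is synthetic rather than taken from the asker's data.

```python
import struct

# Channel count per PNG colour type, as defined by the PNG specification.
PNG_CHANNELS = {
    0: 1,  # grayscale
    2: 3,  # RGB
    3: 1,  # palette index
    4: 2,  # grayscale + alpha
    6: 4,  # RGB + alpha
}

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def png_info(header_bytes):
    """Parse width, height, bit depth and channel count from a PNG header.

    header_bytes: at least the first 26 bytes of a PNG file
    (8-byte signature, 4-byte chunk length, b"IHDR", then the fields).
    """
    if header_bytes[:8] != PNG_SIGNATURE or header_bytes[12:16] != b"IHDR":
        raise ValueError("not a PNG file")
    width, height, bit_depth, color_type = struct.unpack(
        ">IIBB", header_bytes[16:26])
    return {
        "width": width,
        "height": height,
        "bit_depth": bit_depth,
        "channels": PNG_CHANNELS[color_type],
    }


def png_info_from_path(path):
    """Read just enough of the file at `path` to describe its PNG header."""
    with open(path, "rb") as f:
        return png_info(f.read(26))


# Synthetic header for illustration only: a 2048x1024, 8-bit RGBA image.
demo = PNG_SIGNATURE + struct.pack(
    ">I4sIIBBBBB", 13, b"IHDR", 2048, 1024, 8, 6, 0, 0, 0)
print(png_info(demo))

# Hypothetical usage (file names are placeholders):
# print(png_info_from_path("original.png"))
# print(png_info_from_path("replacement.png"))
```

A difference in channel count or bit depth between a default file and its replacement would be consistent with the channel mismatch reported in the error.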