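I am trying out the pix2pixHD code from the following link: https://github.com/NVIDIA/pix2pixHD
train.py works with the default images (in datasets/cityscapes). However, after I replaced the images in the dataset, it fails with the following error.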
model [Pix2PixHDModel] was created
create web directory ./checkpoints/label2city/web...
Traceback (most recent call last):
  File "/home/shimada/venv/py2.7/projects/Hiwi/pix2pixHD/train.py", line 58, in <module>
    Variable(data['image']), Variable(data['feat']), infer=save_fake)
  File "/home/shimada/venv/py2.7/local/lib/python2.7/site-packages/torch/nn/modules/module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/shimada/venv/py2.7/local/lib/python2.7/site-packages/torch/nn/parallel/data_parallel.py", line 66, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/home/shimada/venv/py2.7/local/lib/python2.7/site-packages/torch/nn/modules/module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/shimada/venv/py2.7/projects/Hiwi/pix2pixHD/models/pix2pixHD_model.py", line 141, in forward
    fake_image = self.netG.forward(input_concat)
  File "/home/shimada/venv/py2.7/projects/Hiwi/pix2pixHD/models/networks.py", line 213, in forward
    return self.model(input)
  File "/home/shimada/venv/py2.7/local/lib/python2.7/site-packages/torch/nn/modules/module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/shimada/venv/py2.7/local/lib/python2.7/site-packages/torch/nn/modules/container.py", line 67, in forward
    input = module(input)
  File "/home/shimada/venv/py2.7/local/lib/python2.7/site-packages/torch/nn/modules/module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/shimada/venv/py2.7/local/lib/python2.7/site-packages/torch/nn/modules/conv.py", line 277, in forward
    self.padding, self.dilation, self.groups)
  File "/home/shimada/venv/py2.7/local/lib/python2.7/site-packages/torch/nn/functional.py", line 90, in conv2d
    return f(input, weight, bias)
RuntimeError: Given groups=1, weight[64, 36, 7, 7], so expected input[1, 39, 518, 1030] to have 36 channels, but got 39 channels instead
THCudaCheck FAIL file=/pytorch/torch/lib/THC/generic/THCStorage.c line=184 error=59 : device-side assert triggered
terminate called after throwing an instance of 'std::runtime_error'
what(): cuda runtime error (59) : device-side assert triggered at /pytorch/torch/lib/THC/generic/THCStorage.c:184
bash: line 1: 10965 Aborted (core dumped) env "PYCHARM_HOSTED"="1" "PYTHONUNBUFFERED"="1" "PYTHONIOENCODING"="UTF-8" "PYCHARM_MATPLOTLIB_PORT"="42188" "JETBRAINS_REMOTE_RUN"="1" "PYTHONPATH"="/home/shimada/.pycharm_helpers/pycharm_matplotlib_backend:/home/shimada/venv/py2.7/projects/Hiwi/pix2pixHD" /home/shimada/venv/py2.7/bin/python -u /home/shimada/venv/py2.7/projects/Hiwi/pix2pixHD/train.py
I replaced the images with ones of the same size (width 2048, height 1024), the same extension (.png), and the same file names. Why doesn't it work?
Answer 0 (score: 0)
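Your original images/ground-truth data appear to be grayscale. In that case, you have to specify --input_nc 1 --output_nc 1 for grayscale. You also have to change the pix2pixHD code so that it loads the images as grayscale.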
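
To see why the channel options matter, here is a rough sketch of how the generator's input channel count seems to be put together from the training options. This is my reading of the repository, not its actual code: generator_input_channels is a hypothetical helper, and the defaults (label_nc=35, input_nc=3, feat_num=3) are assumed from the default Cityscapes setup. As far as I can tell, --input_nc only takes effect when --label_nc is 0, and --no_instance drops the extra edge-map channel.

```python
# Hypothetical helper, not part of pix2pixHD. It mirrors how the generator's
# input channels appear to be counted from the training options (assumption).
def generator_input_channels(label_nc=35, input_nc=3, no_instance=False,
                             use_features=False, feat_num=3):
    # With label_nc != 0 the label map is one-hot encoded into label_nc channels;
    # with label_nc == 0 the input image is used directly with input_nc channels.
    channels = label_nc if label_nc != 0 else input_nc
    if not no_instance:
        channels += 1  # one edge-map channel derived from the instance map
    if use_features:
        channels += feat_num  # encoded feature channels, if enabled
    return channels


default_channels = generator_input_channels()
grayscale_channels = generator_input_channels(label_nc=0, input_nc=1,
                                              no_instance=True)
print(default_channels)    # 36, the size in weight[64, 36, 7, 7] from the traceback
print(grayscale_channels)  # 1
```

With the defaults the network expects 36 channels, so any replacement image that yields a different number of channels after loading (39 in the traceback above) will fail at the first convolution.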