与自己的数据集pix2pixHD错误

时间:2017-12-14 18:12:18

标签: tensorflow pytorch

我正在尝试使用pix2pixHD预训练模型生成我自己的图像。 Github repo found here

数据集内的图像必须是灰度,没有alpha通道。 repo中的图像大小为16 bitPerSample,我有两个大小为8和16 bitsPerSample的图像。

当我使用sips -g all检查我的图像和回购中的图像时。这是我得到的结果:

pixelWidth: 2048
pixelHeight: 1024
typeIdentifier: public.png
format: png
formatOptions: default
dpiWidth: 72.000
dpiHeight: 72.000
samplesPerPixel: 1
bitsPerSample: 16
hasAlpha: no
space: Gray

奇怪的是,它适用于具有8 bitPerSample的图像。 这是我得到的结果:

灰度输入 grayscale 转换后的标签图 Input 最终输出 output

当我使用16位PerSample图像运行test.py时,它不起作用。 这是它给我的错误:

model [Pix2PixHDModel] was created
Traceback (most recent call last):
  File "test.py", line 26, in <module>
    for i, data in enumerate(dataset):
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 210, in __next__
    return self._process_next_batch(batch)
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 230, in _process_next_batch
    raise batch.exc_type(batch.exc_msg)
TypeError: Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 42, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 42, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/paperspace/Documents/pix2pixHD/data/aligned_dataset.py", line 41, in __getitem__
    label_tensor = transform_label(label) * 255.0
  File "/usr/local/lib/python3.5/dist-packages/torch/tensor.py", line 309, in __mul__
    return self.mul(other)
TypeError: mul received an invalid combination of arguments - got (float), but expected one of:
 * (int value)
      didn't match because some of the arguments have invalid types: (float)
 * (torch.IntTensor other)
      didn't match because some of the arguments have invalid types: (float)

我对Tensorflow很新,我之前从未使用过pytorch。

知道这个错误意味着什么,我该如何解决?

1 个答案:

答案 0 :(得分:1)

是的,我想我可以帮到你。 我还没有检查过存储库,但是从错误跟踪中看到的问题似乎如下:

您正在transform_label(label)(可能是张量)和标量255.0的输出之间执行乘法运算。只要你的标量和张量都是相同的datatype,这就没问题了。但是,从错误跟踪中看,transform_label()的输出似乎是数据类型Int / Long,而255.0是浮点数。

我建议您尝试255int(255.0)而不是255.0

如果这不能解决您的问题,请告诉我transform_label()的输出是什么数据类型。