我正在尝试使用pix2pixHD预训练模型生成我自己的图像。 Github repo found here
数据集内的图像必须是灰度,没有alpha通道。 repo中的图像大小为16 bitPerSample,我有两个大小为8和16 bitsPerSample的图像。
当我使用sips -g all
检查我的图像和回购中的图像时。这是我得到的结果:
pixelWidth: 2048
pixelHeight: 1024
typeIdentifier: public.png
format: png
formatOptions: default
dpiWidth: 72.000
dpiHeight: 72.000
samplesPerPixel: 1
bitsPerSample: 16
hasAlpha: no
space: Gray
奇怪的是,它适用于具有8 bitPerSample的图像。 这是我得到的结果:
当我使用16位PerSample图像运行test.py
时,它不起作用。
这是它给我的错误:
model [Pix2PixHDModel] was created
Traceback (most recent call last):
File "test.py", line 26, in <module>
for i, data in enumerate(dataset):
File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 210, in __next__
return self._process_next_batch(batch)
File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 230, in _process_next_batch
raise batch.exc_type(batch.exc_msg)
TypeError: Traceback (most recent call last):
File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 42, in _worker_loop
samples = collate_fn([dataset[i] for i in batch_indices])
File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 42, in <listcomp>
samples = collate_fn([dataset[i] for i in batch_indices])
File "/home/paperspace/Documents/pix2pixHD/data/aligned_dataset.py", line 41, in __getitem__
label_tensor = transform_label(label) * 255.0
File "/usr/local/lib/python3.5/dist-packages/torch/tensor.py", line 309, in __mul__
return self.mul(other)
TypeError: mul received an invalid combination of arguments - got (float), but expected one of:
* (int value)
didn't match because some of the arguments have invalid types: (float)
* (torch.IntTensor other)
didn't match because some of the arguments have invalid types: (float)
我对Tensorflow很新,我之前从未使用过pytorch。
知道这个错误意味着什么,我该如何解决?
答案 0 :(得分:1)
是的,我想我可以帮到你。 我还没有检查过存储库,但是从错误跟踪中看到的问题似乎如下:
您正在transform_label(label)
(可能是张量)和标量255.0
的输出之间执行乘法运算。只要你的标量和张量都是相同的datatype
,这就没问题了。但是,从错误跟踪中看,transform_label()
的输出似乎是数据类型Int / Long
,而255.0
是浮点数。
我建议您尝试255
或int(255.0)
而不是255.0
。
如果这不能解决您的问题,请告诉我transform_label()
的输出是什么数据类型。