是火炬.FloatTensor不是Tensor?

时间:2018-03-03 20:03:43

标签: python pytorch

虽然这个例子没有经过培训,但这是一个大型项目的改编版,其中培训确实发生了。我只是希望generator网络在这种情况下喷出随机图像:

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image

class Generator(nn.Module):

    def __init__(self):
        """
        Generator component of GAN. requires an input slightly bigger 
        than 300 x 300 (precisely 308 x 308)
        """
        super(Generator, self).__init__()

        # 5 x 5 square convolution.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 4, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return x


def main():
    # Generate example image.
    generator = Generator()
    img = generator(Variable(torch.randn(1, 3, 308, 308))).data
    img_pil = transforms.ToPILImage()(img)
    img_pil.save("test.png")


if __name__ == "__main__":
    main()

运行此程序将提供以下内容:

(mgan-Csuh5VLx) ➜  mgan git:(broken) ✗ python test.py
Traceback (most recent call last):
  File "test.py", line 34, in <module>
    main()
  File "test.py", line 30, in main
    img_pil = transforms.ToPILImage()(img)
  File "/home/christopher/.local/share/virtualenvs/mgan-Csuh5VLx/lib/python3.6/site-packages/torchvision/transforms/transforms.py", line 92, in __call__
    return F.to_pil_image(pic, self.mode)
  File "/home/christopher/.local/share/virtualenvs/mgan-Csuh5VLx/lib/python3.6/site-packages/torchvision/transforms/functional.py", line 96, in to_pil_image
    raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))
TypeError: pic should be Tensor or ndarray. Got <class 'torch.FloatTensor'>.

我认为FloatTensor s基本上是Tensor s。有没有办法解决这个错误?

(注意:RGBA有四个输出通道,但即使切换到3个输出通道也会产生相同的错误。)

1 个答案:

答案 0 :(得分:2)

只需更改main函数的倒数第二行即可解决问题:

img_pil = transforms.ToPILImage()(img.squeeze())

img.squeeze()使张量形状(1, 4, 300, 300)变为(4, 300, 300)