虽然这个例子没有经过培训,但这是一个大型项目的改编版,其中培训确实发生了。我只是希望generator
网络在这种情况下喷出随机图像:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
class Generator(nn.Module):
def __init__(self):
"""
Generator component of GAN. requires an input slightly bigger
than 300 x 300 (precisely 308 x 308)
"""
super(Generator, self).__init__()
# 5 x 5 square convolution.
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6, 4, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
return x
def main():
# Generate example image.
generator = Generator()
img = generator(Variable(torch.randn(1, 3, 308, 308))).data
img_pil = transforms.ToPILImage()(img)
img_pil.save("test.png")
if __name__ == "__main__":
main()
运行此程序将提供以下内容:
(mgan-Csuh5VLx) ➜ mgan git:(broken) ✗ python test.py
Traceback (most recent call last):
File "test.py", line 34, in <module>
main()
File "test.py", line 30, in main
img_pil = transforms.ToPILImage()(img)
File "/home/christopher/.local/share/virtualenvs/mgan-Csuh5VLx/lib/python3.6/site-packages/torchvision/transforms/transforms.py", line 92, in __call__
return F.to_pil_image(pic, self.mode)
File "/home/christopher/.local/share/virtualenvs/mgan-Csuh5VLx/lib/python3.6/site-packages/torchvision/transforms/functional.py", line 96, in to_pil_image
raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))
TypeError: pic should be Tensor or ndarray. Got <class 'torch.FloatTensor'>.
我认为FloatTensor
s基本上是Tensor
s。有没有办法解决这个错误?
(注意:RGBA有四个输出通道,但即使切换到3个输出通道也会产生相同的错误。)
答案 0 :(得分:2)
只需更改main
函数的倒数第二行即可解决问题:
img_pil = transforms.ToPILImage()(img.squeeze())
img.squeeze()
使张量形状(1, 4, 300, 300)
变为(4, 300, 300)
。