我正在尝试使用torchvision.datasets
中的MNIST数据集。它似乎是作为N x H x W (uint8)
(批处理尺寸,高度,宽度)张量提供的。但是,所有用于图像的pytorch类(例如Conv2d
)都需要一个N x C x H x W (float32)
张量,其中C
是颜色通道的数量。我尝试添加添加ToTensor
变换,但是没有添加颜色通道。
有没有一种方法可以使用torchvision.transforms
添加此附加维度?对于原始的tensor
,我们可以做.unsqueeze(1)
,但这似乎不是一个非常优雅的解决方案。我只是想以“适当”的方式来做。
这是转换失败的地方。
import torchvision
dataset = torchvision.datasets.MNIST("~/PyTorchDatasets/MNIST/", train=True, transform=torchvision.transforms.ToTensor(), download=True)
print(dataset.train_data[0])
答案 0 :(得分:0)
我有一个误解:dataset.train_data
不受指定transform
的影响,只有DataLoader(dataset,...)
的输出会受到影响。从{p>中检查data
后
for data, _ in DataLoader(dataset):
break
我们可以看到ToTensor
确实可以满足您的要求。