通过转换将通道添加到MNIST?

时间:2019-02-15 10:21:48

标签: pytorch mnist python-3.7 torchvision

我正在尝试使用torchvision.datasets中的MNIST数据集。它似乎是作为N x H x W (uint8)(批处理尺寸,高度,宽度)张量提供的。但是,所有用于图像的pytorch类(例如Conv2d)都需要一个N x C x H x W (float32)张量,其中C是颜色通道的数量。我尝试添加添加ToTensor变换,但是没有添加颜色通道。

有没有一种方法可以使用torchvision.transforms添加此附加维度?对于原始的tensor,我们可以做.unsqueeze(1),但这似乎不是一个非常优雅的解决方案。我只是想以“适当”的方式来做。

这是转换失败的地方。

import torchvision
dataset = torchvision.datasets.MNIST("~/PyTorchDatasets/MNIST/", train=True, transform=torchvision.transforms.ToTensor(), download=True)
print(dataset.train_data[0])

1 个答案:

答案 0 :(得分:0)

我有一个误解:dataset.train_data不受指定transform的影响,只有DataLoader(dataset,...)的输出会受到影响。从{p>中检查data

for data, _ in DataLoader(dataset):
    break

我们可以看到ToTensor确实可以满足您的要求。