我们都知道torchvision.datasets
包中包含的公共MNIST数据集。想象一下,我想创建一个仅包含 1 和 0 的数据集的简化版本,以便仅对这两个数字而不是所有10个值进行分类。
我已经看到可以在继承所需数据集的类中创建自定义数据集,因此__getitem__
可以在给定索引处返回项目。所以我做到了:
class MNIST01(MNIST):
def __getitem__(self, idx):
image, label = super().__getitem__(idx)
if label.item() <= 1:
return image, label
else:
return None
问题在于,似乎我无法返回None值,因为它必须是“包含张量,数字,字典或列表;找到的类为'NoneType'”。
是否有一种简单的方法可以以类似的方式轻松获得此数据集的简化版本?
答案 0 :(得分:0)
我终于设法解决了NoneType问题。保持问题中定义的功能。
class MNIST01(MNIST):
def __getitem__(self, idx):
features, target = super(MNIST01, self).__getitem__(idx)
if target.item() <= 1:
return features, target
我们现在需要为数据加载器定义一个自定义collate function collate_fn
,该数据加载器将处理样本列表以形成一个批处理。在此函数中,我们可以应用过滤器来处理None
值,并忽略它们。
from torch.utils.data.dataloader import default_collate
def filter_collate(batch):
batch = list(filter(lambda x: x is not None, batch))
return default_collate(batch)
然后我们只需要将此函数传递给DataLoader
:
from torch.utils.data import DataLoader
train_loader = DataLoader(train_dataset, collate_fn=filter_collate, **kwargs)
test_loader = DataLoader(test_dataset, collate_fn=filter_collate, **kwargs)
版本2
比第一个简单得多,避免了访问数据时的一些问题。只需从train_data
类的实例中直接过滤train_label
和MNIST
属性(并对应于测试集)即可。
train_dataset.train_data = train_dataset.train_data[train_dataset.train_labels <= 1]
train_dataset.train_labels = train_dataset.train_labels[train_dataset.train_labels <= 1]