torchvision.datasets.cifar.CIFAR10列表与否?

时间:2017-07-20 21:42:54

标签: machine-learning pytorch

当我试图找出torchvision.datasets.cifar.CIFAR10里面的内容时,我做了一些简单的代码

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                    download=True, transform=transform)
print(trainset[1])
print(trainset[:10])
print(type(trainset))

然而,当我尝试

时出现了一些错误
print(trainset[:10])

错误信息

TypeError: Cannot handle this data type

我想知道为什么我可以使用trainset[1],而不是trainset[:10]

2 个答案:

答案 0 :(得分:0)

CIFAR10不支持切片,这就是您收到该错误的原因。如果你想要前10个,你将不得不这样做:

class Vector(object):

def __init__(self, coordinates):
    try:
        if not coordinates:
            raise ValueError
        self.coordinates = tuple(coordinates)
        self.dimension = len(coordinates)

def plus(self, v):
    new_coordinates = [x + y for x, y in zip(self.coordinates, v.coordinates)]
    return Vector(new_coordinates)

def __str__(self):
    return 'Vector: {}'.format(self.coordinates)


def __eq__(self, v):
    return self.coordinates == v.coordinates

v = Vector([8.218, -9.341])
w = Vector([-1,129, 2.111])
print (v.plus(w))

更多信息

您可以索引CIFAR10类实例的主要原因是该类实现了print([trainset[i] for i in range(10)]) 函数。

因此,当您致电__getitem__()时,您实际上是在呼叫trainset[i]

现在,在python3中,切片表达式也通过trainset.__getitem__(i)处理,其中切片表达式作为切片对象传递给__getitem__()

因此,__getitem__()相当于trainset[2:10]

由于将两种不同类型的对象传递给trainset.__getitem__(slice(2, 10))  我们应该做完全不同的事情,你必须明确地处理它们。

不幸的是,正如你可以从CIFAR10类的__getitem__方法实现中看到的那样:

__getitem__

答案 1 :(得分:0)

除了对https://stackoverflow.com/a/45226879/7924573的启发,我建议使用 torch.utils.data.dataset.random_split ,例如这种方式:

train_size = int(0.8*len(dataset))
test_size = len(dataset) - train_size
lengths = [train_size, test_size]
train_dataset, valid_dataset = torch.utils.data.dataset.random_split(dataset, lengths)
trainloader = DataLoader(train_data, 
  batch_size=args.train_batch,  
  shuffle=True, 
  num_workers=args.nThreads, 
  pin_memory=True)
validloader = DataLoader(valid_data, 
  batch_size=args.train_batch,  
  shuffle=True, 
  num_workers=args.nThreads, 
  pin_memory=True)

来源:https://yimjiyoung.github.io/2020/02/13/How-to-split-dataset-into-train-and-validation-set-in-pytorch/