当我试图找出torchvision.datasets.cifar.CIFAR10里面的内容时,我做了一些简单的代码
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
print(trainset[1])
print(trainset[:10])
print(type(trainset))
然而,当我尝试
时出现了一些错误print(trainset[:10])
错误信息
TypeError: Cannot handle this data type
我想知道为什么我可以使用trainset[1]
,而不是trainset[:10]
?
答案 0 :(得分:0)
CIFAR10不支持切片,这就是您收到该错误的原因。如果你想要前10个,你将不得不这样做:
class Vector(object):
def __init__(self, coordinates):
try:
if not coordinates:
raise ValueError
self.coordinates = tuple(coordinates)
self.dimension = len(coordinates)
def plus(self, v):
new_coordinates = [x + y for x, y in zip(self.coordinates, v.coordinates)]
return Vector(new_coordinates)
def __str__(self):
return 'Vector: {}'.format(self.coordinates)
def __eq__(self, v):
return self.coordinates == v.coordinates
v = Vector([8.218, -9.341])
w = Vector([-1,129, 2.111])
print (v.plus(w))
您可以索引CIFAR10类实例的主要原因是该类实现了print([trainset[i] for i in range(10)])
函数。
因此,当您致电__getitem__()
时,您实际上是在呼叫trainset[i]
现在,在python3中,切片表达式也通过trainset.__getitem__(i)
处理,其中切片表达式作为切片对象传递给__getitem__()
。
因此,__getitem__()
相当于trainset[2:10]
由于将两种不同类型的对象传递给trainset.__getitem__(slice(2, 10))
我们应该做完全不同的事情,你必须明确地处理它们。
不幸的是,正如你可以从CIFAR10类的__getitem__
方法实现中看到的那样:
__getitem__
答案 1 :(得分:0)
除了对https://stackoverflow.com/a/45226879/7924573的启发,我建议使用 torch.utils.data.dataset.random_split ,例如这种方式:
train_size = int(0.8*len(dataset))
test_size = len(dataset) - train_size
lengths = [train_size, test_size]
train_dataset, valid_dataset = torch.utils.data.dataset.random_split(dataset, lengths)
trainloader = DataLoader(train_data,
batch_size=args.train_batch,
shuffle=True,
num_workers=args.nThreads,
pin_memory=True)
validloader = DataLoader(valid_data,
batch_size=args.train_batch,
shuffle=True,
num_workers=args.nThreads,
pin_memory=True)