我正在处理将图像分为10类的图像数据集(CIFAR10数据集)。我正在使用PyTorch。请,我想知道如何通过遍历数据集来确定每个类的图像数量。预先感谢您的回复。
答案 0 :(得分:1)
您可以进行两次操作。
下一步遍历数据集,并根据img_dict中的类键继续增加值
在此处输入代码
dataset_size = len(dataset)
classes = dataset.classes
num_classes = len(dataset.classes)
img_dict = {}
for i in range(num_classes):
img_dict[classes[i]] = 0
for i in range(dataset_size):
img, label = dataset[i]
img_dict[classes[label]] += 1
img_dict
您将得到如下输出: