PyTorch数据集中每个类的图像数

时间:2020-06-08 16:13:08

标签: pytorch

我正在处理将图像分为10类的图像数据集(CIFAR10数据集)。我正在使用PyTorch。请,我想知道如何通过遍历数据集来确定每个类的图像数量。预先感谢您的回复。

1 个答案:

答案 0 :(得分:1)

您可以进行两次操作。

  • 首先创建一个字典img_dict,其中包含CIFAR10数据集的所有类。将所有值初始化为0。
  • 下一步遍历数据集,并根据img_dict中的类键继续增加值

    在此处输入代码

dataset_size = len(dataset)
classes = dataset.classes
num_classes = len(dataset.classes)
img_dict = {}
for i in range(num_classes):
    img_dict[classes[i]] = 0

for i in range(dataset_size):
    img, label = dataset[i]
    img_dict[classes[label]] += 1

img_dict

您将得到如下输出:

number of images per class:https://imgur.com/a/u4BcnDW