pytorch:获取给定数据集的类数

时间:2019-03-19 07:27:51

标签: python machine-learning deep-learning pytorch

如果我有一个像这样的数据集:

image_datasets['train'] = datasets.ImageFolder(train_dir, transform=train_transforms)

如何以编程方式确定数据集中的类或唯一标签的数量?

2 个答案:

答案 0 :(得分:1)

如果您的数据类型是张量,则可以使用:

import torch n_classes = len(torch.unique(Your_Target_Vector))

答案 1 :(得分:-1)

使用:

len(image_datasets['train'].classes)

.classes返回一个列表。