如何在pytorch中对数据集进行排序

时间:2019-03-03 00:07:34

标签: python machine-learning pytorch mnist

我想按标签中的数值对数据集进行排序。

pytorch有功能可以有效地处理此问题吗?

我的数据集type()来自:

 <class 'torchvision.datasets.mnist.MNIST'>

1 个答案:

答案 0 :(得分:0)

没有有效的通用方法,因为数据集类仅实现__getitem____len__方法,并且不一定具有有关标签的任何“存储”信息。 / p>

但是对于MNIST dataset类,您可以从标签列表中对数据集进行排序。

例如,当您要列出标签为5的索引时。

mnist = torchvision.datasets.mnist.MNIST("/")
labels = mnist.train_labels
fives = (labels == 5).nonzero()