如何从pytorch DataLoader获取特定样本?

时间:2020-07-06 15:25:13

标签: pytorch

在Pytorch中,是否可以使用torch.utils.data.DataLoader类加载特定于 的单个样本?我想用它做一些测试。

tutorial使用

trainloader = torch.utils.data.DataLoader(...)
images, labels = next(iter(trainloader))

提取一批 random 样本。有没有办法使用DataLoader来获取特定于 的示例?

欢呼

2 个答案:

答案 0 :(得分:1)

  • 关闭types中的auto
  • 使用shuffle计算您要查找的所需样品所属的批次
  • 迭代到所需的批次

代码

DataLoader

输出:

batch_size

答案 1 :(得分:0)

如果您想从数据集中获取特定的信号样本,可以
您应该检查子集类。(https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset) 像这样的东西:

indices =  [0,1,2]  # select your indices here as a list  
subset = torch.utils.data.Subset(train_set, indices)
trainloader = DataLoader(subset , batch_size =  16  , shuffle =False) #set shuffle to False 

for image , label in trainloader:
   print(image.size() , '\t' , label.size())
   print(image[0], '\t' , label[0]) # index the specific sample 

如果您想了解有关Pytorch数据加载实用程序的更多信息,这是一个有用的链接 (https://pytorch.org/docs/stable/data.html

相关问题