将Pytorch ImageFolder数据集与自定义Pytorch数据集结合

时间:2020-06-09 17:49:36

标签: python pytorch

我正在使用Pytorch从多个数据集中加载数据。我将一些图像存储在标签正确的文件夹中(例如\ 0和\ 1),在这种情况下,例如,在加载列表后,我可以使用torch.utils.data.ConcatDataset(其中trans是一组预定义的Pytorch转换):

l = []
l.append(datasets.ImageFolder(file_path, trans))
l.append(datasets.ImageFolder(file_path2, trans))
image_datasets = torch.utils.data.ConcatDataset(l)

img_datasets = dict()
img_datasets['train'], img_datasets['val'] = torch.utils.data.random_split(image_datasets, (round(0.8*len(image_datasets)), round(0.2*len(image_datasets)) ))

但是,我也正在使用csv文件从其他不同的位置加载图像。这里的过程看起来像这样:

class MyData(Dataset):
  def __init__(self, df):
      self.df = df

  def __len__(self):
      return self.df.shape[0]

  def __getitem__(self, index):
      image = trans(PIL.Image.open(self.df.file_path[index]))
      label = self.df.label[index]

      return image, label


df = pd.read_csv(image_file_paths), names=["file_path", "label"])
mydata = MyData(df)

my_datasets = dict()
my_datasets['train'], my_datasets['val'] = torch.utils.data.random_split(mydata, (round(0.8*len(mydata)), round(0.2*len(mydata))))

因此,我希望能够将这些数据集合并到一个数据加载器中。关于我应该如何做的任何想法?谢谢!

1 个答案:

答案 0 :(得分:0)

找到了解决方案;只需使用ConcatDataset的多次通过:

l = []
l.append(datasets.ImageFolder(file_path, trans))
l.append(datasets.ImageFolder(file_path2, trans))
image_datasets = torch.utils.data.ConcatDataset(l)

df = pd.read_csv(image_file_paths), names=["file_path", "label"])
mydata = MyData(df)

image_datasets = torch.utils.data.ConcatDataset([image_datasets, mydata])

img_datasets = dict()
img_datasets['train'], img_datasets['val'] = torch.utils.data.random_split(image_datasets, (round(0.8*len(image_datasets)), round(0.2*len(image_datasets))))

很高兴从那里走。