我正在使用Pytorch从多个数据集中加载数据。我将一些图像存储在标签正确的文件夹中(例如\ 0和\ 1),在这种情况下,例如,在加载列表后,我可以使用torch.utils.data.ConcatDataset(其中trans是一组预定义的Pytorch转换):
l = []
l.append(datasets.ImageFolder(file_path, trans))
l.append(datasets.ImageFolder(file_path2, trans))
image_datasets = torch.utils.data.ConcatDataset(l)
img_datasets = dict()
img_datasets['train'], img_datasets['val'] = torch.utils.data.random_split(image_datasets, (round(0.8*len(image_datasets)), round(0.2*len(image_datasets)) ))
但是,我也正在使用csv文件从其他不同的位置加载图像。这里的过程看起来像这样:
class MyData(Dataset):
def __init__(self, df):
self.df = df
def __len__(self):
return self.df.shape[0]
def __getitem__(self, index):
image = trans(PIL.Image.open(self.df.file_path[index]))
label = self.df.label[index]
return image, label
df = pd.read_csv(image_file_paths), names=["file_path", "label"])
mydata = MyData(df)
my_datasets = dict()
my_datasets['train'], my_datasets['val'] = torch.utils.data.random_split(mydata, (round(0.8*len(mydata)), round(0.2*len(mydata))))
因此,我希望能够将这些数据集合并到一个数据加载器中。关于我应该如何做的任何想法?谢谢!
答案 0 :(得分:0)
找到了解决方案;只需使用ConcatDataset的多次通过:
l = []
l.append(datasets.ImageFolder(file_path, trans))
l.append(datasets.ImageFolder(file_path2, trans))
image_datasets = torch.utils.data.ConcatDataset(l)
df = pd.read_csv(image_file_paths), names=["file_path", "label"])
mydata = MyData(df)
image_datasets = torch.utils.data.ConcatDataset([image_datasets, mydata])
img_datasets = dict()
img_datasets['train'], img_datasets['val'] = torch.utils.data.random_split(image_datasets, (round(0.8*len(image_datasets)), round(0.2*len(image_datasets))))
很高兴从那里走。