来自文件路径和标签的csv的Pytorch数据加载器

时间:2020-06-08 21:06:49

标签: python pytorch

我有一个用于火车和测试数据集的csv文件,其中包含文件位置和标签。该数据框的开头是:

df.head()
Out[46]: 
             file_path  label
0  \\images\\29771.png      0
1  \\images\\55201.png      0
2  \\images\\00715.png      1
3  \\images\\33214.png      0
4  \\images\\99841.png      1

我在文件路径中有多个位置,而且空间有限,因此无法将它们复制到\ 0和\ 1文件夹位置。如何使用此数据框创建pytorch数据加载器和/或数据集对象?

1 个答案:

答案 0 :(得分:2)

只需为数据集编写自定义__getitem__方法即可。

class MyData(Dataset):
    def __init__(self, df):
        self.df = df

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        image = load_image(self.df.file_path[index])
        label = self.df.label[index]

        return image, label

load_image的功能是将文件名读取为所需的任何格式。