Pytorch:如何使用参考表创建自定义数据集

时间:2019-06-21 02:33:06

标签: python pytorch torchvision

我有一个reference.csv文件,该文件具有三列:Type,Class和Path。这是前5个示例行:

"Type","Class","Path"
"train","A","./path1/001.jpg"
"train","A","./path2/002.jpg"
"test","C","./path3/003.jpg"
"train","B","./path4/001.jpg"
"test","B","./path5/002.jpg"
...

以一种对观众更友好的格式:

|----------------------|------------------|------------------|
|         Type         |       Class      |       Path       |
|----------------------|------------------|------------------|
|        train         |         A        | ./path1/001.jpg  |
|----------------------|------------------|------------------|
|        train         |         A        | ./path2/002.jpg  |
|----------------------|------------------|------------------|
|        train         |         C        | ./path3/003.jpg  |
|----------------------|------------------|------------------|
|        test          |         B        | ./path4/001.jpg  |
|----------------------|------------------|------------------|
|        test          |         B        | ./path5/002.jpg  |
|----------------------|------------------|------------------|

我想创建一个数据集类(torch.utils.data.Dataset)来读取图像,以便可以使用DataLoader(torch.utils.data.DataLoader)。

使用参考表创建自定义数据集的正确方法是什么?

1 个答案:

答案 0 :(得分:1)

如果我们要构建一个自定义数据集来读取此csv文件中的图像位置,则可以执行以下操作。您的逻辑可能会有所不同。

class CustomDatasetFromImages(Dataset):
    def __init__(self, csv_path):
        """
        Args:
            csv_path (string): path to csv file
            img_path (string): path to the folder where images are
            transform: pytorch transforms for transforms and tensor conversion
        """
        # Transforms
        self.to_tensor = transforms.ToTensor()

        # Read the csv file
        self.data_info = pd.read_csv(csv_path, header=None)

        # First column contains the image paths
        self.image_arr = np.asarray(self.data_info.iloc[:, 0])

        # Second column is the labels
        self.label_arr = np.asarray(self.data_info.iloc[:, 1])

        # Third column is for an operation indicator
        self.operation_arr = np.asarray(self.data_info.iloc[:, 2])

        # Calculate len
        self.data_len = len(self.data_info.index)

    def __getitem__(self, index):

        # Get image name from the pandas df
        single_image_name = self.image_arr[index]

        # Open image
        img_as_img = Image.open(single_image_name)

        # Check if there is an operation
        some_operation = self.operation_arr[index]

        # If there is an operation
        if some_operation:
            # Do some operation on image
            # ...
            # ...
            pass

        # Transform image to tensor
        img_as_tensor = self.to_tensor(img_as_img)

        # Get label(class) of the image based on the cropped pandas column
        single_image_label = self.label_arr[index]

        return (img_as_tensor, single_image_label)

    def __len__(self):
        return self.data_len

if __name__ == "__main__":
    # Call dataset
    custom_mnist_from_images =  \
        CustomDatasetFromImages('../data/mnist_labels.csv')
相关问题