自定义数据集和数据加载器

时间:2021-01-27 14:02:48

标签: dataframe computer-vision pytorch dataloader

我是 pytorch 的新手。 我的大数据集由两个 txt 文件组成,一个用于数据,另一个用于目标数据。 在训练文件中每行是长度为 340 的列表,在目标中每行是长度为 136 的列表。

我想问一下如何定义我的数据集,以便我可以使用 Dataloader 加载我的数据来训练 pytorch 模型?

我希望你回答

1 个答案:

答案 0 :(得分:0)

Dataset 中的

torch.utils.data 是表示数据集的抽象类。您的自定义数据集应继承 Dataset 并覆盖以下方法:

__len__() 使 len(dataset) 返回数据集的大小。
__getitem__() 支持索引,使得 dataset[i] 可用于获取第 i 个样本

例如编写自定义数据集
我已经为您编写了一个通用的自定义数据加载器作为您的问题陈述。
这里 data.txt 有数据,label.txt 有标签。

import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self):
        
       
        with open('data.txt', 'r') as f:
                self.data_info = f.readlines()
        
        with open('label.txt', 'r') as f:
                self.label_info = f.readlines()        


    def __getitem__(self, index):
        
        single_data = self.data_info[index].rstrip('\n')
        

        single_label = self.label_info[index].rstrip('\n')

        return ( single_data , single_label)

    def __len__(self):
        return len(self.data_info)
# Testing 
d = CustomDataset()
print(d[1]) # should output data along with label

这将是您案例的基础,但必须进行一些与您的案例相匹配的更改。

注意:您必须根据数据集进行必要的更改

相关问题