使用DataLoader

时间:2019-02-17 20:52:55

标签: python machine-learning pytorch

我有一个.csv文件,格式如下:

x1, x2, y, x3, x4
05  05  0  00  12
01  09  1  00  17
       ...
13  24  0  01  00

我创建了下面的类,该类将帮助我用DataLoader加载数据,以便对其进行混洗和批量读取,以及应用我可能想要的任何可能的转换(例如,规范化)最终目的是创建一个简单的RNN。

class MyDataset(Dataset):    
    def __init__(self, csv_file, transform=None):
        self.data_frame = pd.read_csv(csv_file)
        self.transform = transform

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        m_data = self.data_frame.iloc[idx, :].as_matrix()
        m_data = m_data.astype('float')
        sample = {'y': np.array([m_data[2]]), 'x': np.delete(m_data, 2)}

        if self.transform:
            sample = self.transform(sample)

        return sample

然后我可以用类似的方式加载数据

dataset = MyDataset(csv_file=file, transform=transforms.Compose([ToTensor()]))    
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)

这是PyTorch的{​​{3}}教程之后的内容,它用于加载图像数据集。

我的问题是,由于我拥有完全不同的数据集,因此我对我的情况很敏感,还是我应该以其他方式更有效地加载和处理数据?

0 个答案:

没有答案