我有一个.csv
文件,格式如下:
x1, x2, y, x3, x4
05 05 0 00 12
01 09 1 00 17
...
13 24 0 01 00
我创建了下面的类,该类将帮助我用DataLoader
加载数据,以便对其进行混洗和批量读取,以及应用我可能想要的任何可能的转换(例如,规范化)最终目的是创建一个简单的RNN。
class MyDataset(Dataset):
def __init__(self, csv_file, transform=None):
self.data_frame = pd.read_csv(csv_file)
self.transform = transform
def __len__(self):
return len(self.data_frame)
def __getitem__(self, idx):
m_data = self.data_frame.iloc[idx, :].as_matrix()
m_data = m_data.astype('float')
sample = {'y': np.array([m_data[2]]), 'x': np.delete(m_data, 2)}
if self.transform:
sample = self.transform(sample)
return sample
然后我可以用类似的方式加载数据
dataset = MyDataset(csv_file=file, transform=transforms.Compose([ToTensor()]))
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
这是PyTorch的{{3}}教程之后的内容,它用于加载图像数据集。
我的问题是,由于我拥有完全不同的数据集,因此我对我的情况很敏感,还是我应该以其他方式更有效地加载和处理数据?