我试图从磁盘上的数据中批量找到Pytorch训练的一个例子 - 类似于Keras fit_generator。我如何更改下面的代码从磁盘读取csv而不是将其加载到内存?
我发现可以迭代自定义数据加载器,如下所示,但我不确定如何在不加载内存中的所有数据的情况下执行此操作。
我想:
重复x纪元
class testLoader(Dataset):
def __init__(self):
#regular old numpy
boston = load_boston()
x=boston.data
y=boston.target
self.x = torch.from_numpy(x)
self.y = torch.from_numpy(y)
self.length = x.shape[0]
self.vars =x.shape[1]
def __getitem__(self, index):
return self.x[index], self.y[index]
def __len__(self):
return self.length
training_samples=testLoader()
train_loader = utils_data.DataLoader(training_samples, batch_size=64, shuffle=True)
答案 0 :(得分:0)
我将从这两个教程开始:
pytorch网站上有一些非常容易理解的教程。我真的可以建议您仔细阅读它们,以了解框架的工作原理。 https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
斯坦福大学的一些教程可以很好地补充pytorch。我可以推荐这个: https://stanford.edu/~shervine/blog/pytorch-how-to-generate-data-parallel