Pytorch数据不适合内存 - 示例?

时间:2018-02-09 17:42:45

标签: pytorch

我试图从磁盘上的数据中批量找到Pytorch训练的一个例子 - 类似于Keras fit_generator。我如何更改下面的代码从磁盘读取csv而不是将其加载到内存?

我发现可以迭代自定义数据加载器,如下所示,但我不确定如何在不加载内存中的所有数据的情况下执行此操作。

我想:

  • 使用磁盘上保存的数据训练和验证模型
  • 在磁盘上使用小批量的完整数据
  • 重复x纪元

    class testLoader(Dataset):

          def __init__(self):
    
                #regular old numpy
                boston = load_boston()
                x=boston.data
                y=boston.target
    
                self.x = torch.from_numpy(x)
                self.y = torch.from_numpy(y)
    
    
                self.length = x.shape[0]
                self.vars =x.shape[1]
    
          def __getitem__(self, index):
    
              return self.x[index], self.y[index]
    
          def __len__(self):
              return self.length
    
    
    training_samples=testLoader()
    train_loader = utils_data.DataLoader(training_samples, batch_size=64, shuffle=True)
    

1 个答案:

答案 0 :(得分:0)

我将从这两个教程开始:

pytorch网站上有一些非常容易理解的教程。我真的可以建议您仔细阅读它们,以了解框架的工作原理。 https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

斯坦福大学的一些教程可以很好地补充pytorch。我可以推荐这个: https://stanford.edu/~shervine/blog/pytorch-how-to-generate-data-parallel