pytorch:如何在多个文件夹中加载csv文件

时间:2020-07-29 13:49:28

标签: python pytorch

我的数据如下:

folder 1
  part0001.csv
  part0002.csv
  ...
  part0199.csv
folder 2
  part0001.csv
  part0002.csv
  ...
  part0199.csv
folder 3
  part0001.csv
  part0002.csv
  ...
  part0199.csv

更新

每个.csv文件约为100Mb。功能和label都在同一个.csv文件中。每个.csv文件如下。

  feat1 feat2 label
1 1     3     0
2 3     4     1
3 2     5     0
...

我想将样本批量加载到.csv文件中。

1 个答案:

答案 0 :(得分:1)

您必须构建一个加载它们的数据集。 (文档:https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset

示例:

import torch
from torch.ults.data import Dataset
import glob2
import pandas as pd

class CustomDataset(Dataset):

    def __init__(self, root)
        self.root = root

        # make a list containing the path to all your csv files
        self.paths =  glob2.glob('src/**/*.csv)
    
    def __len__(self):
        return len(self.paths)

    def __getittem__(idx):

        data = pd.from_csv(self.paths[idx])
        x = data['features']
        y = data['labels']

        return x, y

这是基本操作,您可以对其进行修改以从每个csv文件中抽样随机样本,或者在训练之前对数据进行预处理。

修改

如果您只是与csv一行,那么您可以做三件事。

  1. 对数据进行预处理,然后将其保存为一个大的.csv文件,并将其全部加载到内存中,然后再进行训练。这样可以避免重载文件的麻烦。
  2. (如果由于最终文件而无法使用前一个文件,则将无法容纳在内存中)预处理数据并按数据点将其保存为.csv文件。仍然需要您的数据加载器从光盘读取数据,但是至少您这次将加载较轻的文件。
  3. (如果不能对数据进行预处理,请尽可能保留在内存中,以避免重新加载文件。)

实现前两个解决方案的秘密并不多。解决方案3的代码应如下所示:

import torch
from torch.ults.data import Dataset
import glob2
import pandas as pd

class CustomDataset(Dataset):

    def __init__(self, root)
        self.root = root

        # make a list containing the path to all your csv files
        self.paths =  glob2.glob('src/**/*.csv)
        
        # dict to keep load data in memory:
        self.cache = {}

    
    def __len__(self):
        return len(self.paths)

    def __getittem__(idx):
        """This getittem will load data and save them in memory during training."""
        data = cache.get(idx, None)
        
        if data is None:

            data = pd.from_csv(self.paths[idx])
            
            try:
                # cache data into memory
                self.cache{idx: data}
            except OSError:

                # we may be using too much memory
                del self.cache[list(self.cache.keys())[0]]

        rnd_idx = np.random.randint(len(data))
        x = data['features'][rdn_idx]
        y = data['labels'][rdn_idx]
 
        return x, y