使用IterableDataset加载巨大的自定义数据集

时间:2020-08-07 09:06:42

标签: python pytorch

我有一个庞大的数据集,其特征(input_id,input_mask,segment_id,label_id)已成批保存在咸菜文件中,共64批。我阅读了此文件,创建了一个TensorDataset并传递给数据加载器进行培训。由于功能文件太大而无法创建完整的TensorDataset,因此我想将TensorDataset转换为IterableDataset,以便可以一次从功能文件中检索一批样本并将其传递给数据加载器。但是在训练时,出现以下错误: TypeError: iter() returned non-iterator of type 'TensorDataset'

以下是我编写的自定义数据集类:

class MyDataset(IterableDataset):

    def __init__(self,args):
        self.args=args
       
    def get_features(self,filename):
        with open(filename, "rb") as f:
            while True:
                try:
                    yield pickle.load(f)
                except EOFError:
                    break  
                    
    def process(self,args):
        if args.cached_features_file:
            cached_features_file = args.cached_features_file

        if os.path.exists(cached_features_file):
            features=self.get_features(cached_features_file)

        feat = next (features)
        li=list(feat)
        all_input_ids=torch.tensor([f.input_ids for f in li ], dtype=torch.long)
        all_input_mask= torch.tensor([f.input_mask for f in li ], dtype=torch.long)
        all_segment_ids= torch.tensor([f.segment_ids for f in li], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in li ], dtype=torch.long)
        
        dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        return dataset
      
    def __iter__(self):
        dataset=self.process(self.args)       
        return dataset

我这样使用它:

train_dataset=MyDataset(args)
train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)

我知道TensorDataset是需要索引的地图样式,而IterableDataset是iterable样式,这就是错误的原因。即使我返回特征张量的列表/元组而不是TensorDataset,我也会遇到类似的错误。有人可以告诉我如何使用IterableDataset以正确的方式加载批处理数据集吗?

1 个答案:

答案 0 :(得分:1)

我通过以其他方式保存数据集解决了该问题。我将这些功能保存为在腌制文件中逐渐腌制的字典对象,然后一次将其读取一次,然后传递给数据加载器进行处理。批处理由数据加载器自动完成。自定义类现在的外观如下:

class MyDataset(IterableDataset):

    def __init__(self,filename):
     
        self.filename=filename
        super().__init__()
                    
    def process(self,filename):
        with open(filename, "rb") as f:
            while True:
                try:
                    yield pickle.load(f)
                except EOFError:
                    break

    def __iter__(self):
        dataset=self.process(self.filename)          
        return dataset