我有一个庞大的数据集,其特征(input_id,input_mask,segment_id,label_id)已成批保存在咸菜文件中,共64批。我阅读了此文件,创建了一个TensorDataset并传递给数据加载器进行培训。由于功能文件太大而无法创建完整的TensorDataset,因此我想将TensorDataset转换为IterableDataset,以便可以一次从功能文件中检索一批样本并将其传递给数据加载器。但是在训练时,出现以下错误:
TypeError: iter() returned non-iterator of type 'TensorDataset'
以下是我编写的自定义数据集类:
class MyDataset(IterableDataset):
def __init__(self,args):
self.args=args
def get_features(self,filename):
with open(filename, "rb") as f:
while True:
try:
yield pickle.load(f)
except EOFError:
break
def process(self,args):
if args.cached_features_file:
cached_features_file = args.cached_features_file
if os.path.exists(cached_features_file):
features=self.get_features(cached_features_file)
feat = next (features)
li=list(feat)
all_input_ids=torch.tensor([f.input_ids for f in li ], dtype=torch.long)
all_input_mask= torch.tensor([f.input_mask for f in li ], dtype=torch.long)
all_segment_ids= torch.tensor([f.segment_ids for f in li], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in li ], dtype=torch.long)
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
return dataset
def __iter__(self):
dataset=self.process(self.args)
return dataset
我这样使用它:
train_dataset=MyDataset(args)
train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)
我知道TensorDataset是需要索引的地图样式,而IterableDataset是iterable样式,这就是错误的原因。即使我返回特征张量的列表/元组而不是TensorDataset,我也会遇到类似的错误。有人可以告诉我如何使用IterableDataset以正确的方式加载批处理数据集吗?
答案 0 :(得分:1)
我通过以其他方式保存数据集解决了该问题。我将这些功能保存为在腌制文件中逐渐腌制的字典对象,然后一次将其读取一次,然后传递给数据加载器进行处理。批处理由数据加载器自动完成。自定义类现在的外观如下:
class MyDataset(IterableDataset):
def __init__(self,filename):
self.filename=filename
super().__init__()
def process(self,filename):
with open(filename, "rb") as f:
while True:
try:
yield pickle.load(f)
except EOFError:
break
def __iter__(self):
dataset=self.process(self.filename)
return dataset