我有一个自定义的Pytorch数据集,该数据集返回包含类对象“ queries”的字典。
class QueryDataset(torch.utils.data.Dataset):
def __init__(self, queries, values, targets):
super(QueryDataset).__init__()
self.queries = queries
self.values = values
self.targets = targets
def __len__(self):
return self.values.shape[0]
def __getitem__(self, idx):
sample = DeviceDict({'query': self.queries[idx],
"values": self.values[idx],
"targets": self.targets[idx]})
return sample
问题是,当我将查询放入数据加载器时,我得到default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'query.Query'>
。有没有办法在我的数据加载器中有一个类对象?在下面的代码中,它在next(iterator)
处爆炸。
train_queries = QueryDataset(train_queries)
train_loader = torch.utils.data.DataLoader(train_queries,
batch_size=10],
shuffle=True,
drop_last=False)
for i in range(epochs):
iterator = iter(train_loader)
for i in range(len(train_loader)):
batch = next(iterator)
out = model(batch)
loss = criterion(out["pred"], batch["targets"])
self.optimizer.zero_grad()
loss.sum().backward()
self.optimizer.step()
答案 0 :(得分:1)
您需要定义自己的colate_fn才能做到这一点。 仅仅向您展示事物在这里如何工作的草率方法将是这样的:
import torch
class DeviceDict:
def __init__(self, data):
self.data = data
def print_data(self):
print(self.data)
class QueryDataset(torch.utils.data.Dataset):
def __init__(self, queries, values, targets):
super(QueryDataset).__init__()
self.queries = queries
self.values = values
self.targets = targets
def __len__(self):
return 5
def __getitem__(self, idx):
sample = {'query': self.queries[idx],
"values": self.values[idx],
"targets": self.targets[idx]}
return sample
def custom_collate(dict):
return DeviceDict(dict)
dt = QueryDataset("q","v","t")
dl = torch.utils.data.DataLoader(dtt,batch_size=1,collate_fn=custom_collate)
t = next(iter(dl))
t.print_data()
基本上,colate_fn
允许您实现自定义批处理或添加对自定义数据类型的支持,如我先前提供的链接中所述。
如您所见,它只是显示了概念,您需要根据自己的需要对其进行更改。
答案 1 :(得分:1)
对于那些好奇的人来说,这是我用来使事情正常工作的DeviceDict和自定义整理功能。
class DeviceDict(dict):
def __init__(self, *args):
super(DeviceDict, self).__init__(*args)
def to(self, device):
dd = DeviceDict()
for k, v in self.items():
if torch.is_tensor(v):
dd[k] = v.to(device)
else:
dd[k] = v
return dd
def collate_helper(elems, key):
if key == "query":
return elems
else:
return torch.utils.data.dataloader.default_collate(elems)
def custom_collate(batch):
elem = batch[0]
return DeviceDict({key: collate_helper([d[key] for d in batch], key) for key in elem})