Adding class objects to a PyTorch DataLoader: batch must contain tensors

Time: 2020-10-29 07:18:33

Tags: python pytorch

I have a custom PyTorch dataset that returns a dictionary containing a class object, "queries".

class QueryDataset(torch.utils.data.Dataset):

    def __init__(self, queries, values, targets):
        super().__init__()
        self.queries = queries
        self.values = values
        self.targets = targets

    def __len__(self):
        return self.values.shape[0]

    def __getitem__(self, idx):
        sample = DeviceDict({'query': self.queries[idx],
                             "values": self.values[idx],
                             "targets": self.targets[idx]})
        return sample

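The problem is that when I put the queries into the DataLoader, I get default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'query.Query'>. Is there a way to have a class object in my DataLoader? In the code below, it blows up at next(iterator).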
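
The error can be reproduced without a DataLoader by calling default_collate directly on a list of samples. The sketch below is only an illustration: the Query class is a hypothetical stand-in for the asker's query.Query, and the tensor shapes are made up.

```python
import torch
from torch.utils.data.dataloader import default_collate


class Query:
    """Hypothetical stand-in for the asker's query.Query class."""

    def __init__(self, text):
        self.text = text


# Two samples shaped like the dictionaries returned by __getitem__.
samples = [
    {"query": Query("a"), "values": torch.zeros(3), "targets": torch.tensor(0.0)},
    {"query": Query("b"), "values": torch.ones(3), "targets": torch.tensor(1.0)},
]

try:
    default_collate(samples)
    error_message = None
except TypeError as exc:
    # default_collate has no rule for arbitrary Python objects.
    error_message = str(exc)

print(error_message)
```

The tensors under "values" and "targets" would collate fine on their own; it is only the "query" entry that default_collate does not know how to batch.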

train_queries = QueryDataset(train_queries)  # note: __init__ above also expects values and targets
train_loader = torch.utils.data.DataLoader(train_queries,
                                           batch_size=10,
                                           shuffle=True,
                                           drop_last=False)
for epoch in range(epochs):
    iterator = iter(train_loader)
    for _ in range(len(train_loader)):
        batch = next(iterator)  # the collate error is raised here
        out = model(batch)
        loss = criterion(out["pred"], batch["targets"])
        self.optimizer.zero_grad()
        loss.sum().backward()
        self.optimizer.step()

2 answers:

Answer 0: (score: 1)

You need to define your own collate_fn to do this. A sloppy approach, just to show you how things work here, would be something like this:

import torch
class DeviceDict:
    def __init__(self, data):
        self.data = data 

    def print_data(self):
        print(self.data)

class QueryDataset(torch.utils.data.Dataset):

    def __init__(self, queries, values, targets):
        super().__init__()
        self.queries = queries
        self.values = values
        self.targets = targets

    def __len__(self):
        return len(self.queries)

    def __getitem__(self, idx):
        sample = {'query': self.queries[idx],
                 "values": self.values[idx],
                 "targets": self.targets[idx]}
        return sample

def custom_collate(batch):
    # batch is the list of samples returned by __getitem__
    return DeviceDict(batch)

dt = QueryDataset("q","v","t")
dl = torch.utils.data.DataLoader(dt, batch_size=1, collate_fn=custom_collate)
t = next(iter(dl))
t.print_data()

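Basically, collate_fn lets you implement custom batching or add support for custom data types, as explained in the link I provided earlier.
As you can see, it only demonstrates the concept; you need to adapt it to your own needs.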
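
To make it clearer what a collate_fn actually receives, here is a small sketch. SquaresDataset and identity_collate are hypothetical names used only for this illustration; the point is that the DataLoader hands collate_fn a plain Python list with one entry per sample in the batch.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: item i is the pair (i, i * i)."""

    def __len__(self):
        return 6

    def __getitem__(self, idx):
        return idx, idx * idx


def identity_collate(samples):
    # `samples` is a list holding one __getitem__ result per batch element.
    # Returning it unchanged skips all tensor stacking.
    return samples


loader = DataLoader(SquaresDataset(), batch_size=3, shuffle=False,
                    collate_fn=identity_collate)
batches = list(loader)
print(batches)
```

Whatever collate_fn returns is what the training loop sees as a batch, so it can be a list, a dict, or a custom class instance.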

Answer 1: (score: 1)

For those who are curious, here are the DeviceDict and the custom collate function I used to get things working.

import torch


class DeviceDict(dict):

    def __init__(self, *args):
        super(DeviceDict, self).__init__(*args)

    def to(self, device):
        dd = DeviceDict()
        for k, v in self.items():
            if torch.is_tensor(v):
                dd[k] = v.to(device)
            else:
                dd[k] = v
        return dd


def collate_helper(elems, key):
    if key == "query":
        return elems
    else:
        return torch.utils.data.dataloader.default_collate(elems)


def custom_collate(batch):
    elem = batch[0]
    return DeviceDict({key: collate_helper([d[key] for d in batch], key) for key in elem})
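
As a rough end-to-end check, the sketch below wires a condensed version of the same idea into a DataLoader. The Query class, ToyQueryDataset, and the tensor shapes are assumptions made up for this example, and DeviceDict and custom_collate are restated in compact form so the snippet runs on its own.

```python
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate


class Query:
    """Hypothetical stand-in for the custom query class."""

    def __init__(self, text):
        self.text = text


class DeviceDict(dict):
    """A dict whose tensor values can be moved to a device together."""

    def to(self, device):
        return DeviceDict({k: v.to(device) if torch.is_tensor(v) else v
                           for k, v in self.items()})


def custom_collate(batch):
    # Keep "query" objects as a plain list; collate every other key as usual.
    return DeviceDict({
        key: [d[key] for d in batch] if key == "query"
        else default_collate([d[key] for d in batch])
        for key in batch[0]
    })


class ToyQueryDataset(Dataset):
    """Hypothetical dataset mirroring the structure used in the question."""

    def __init__(self, queries, values, targets):
        super().__init__()
        self.queries, self.values, self.targets = queries, values, targets

    def __len__(self):
        return self.values.shape[0]

    def __getitem__(self, idx):
        return {"query": self.queries[idx],
                "values": self.values[idx],
                "targets": self.targets[idx]}


dataset = ToyQueryDataset([Query(str(i)) for i in range(4)],
                          torch.arange(12.0).reshape(4, 3),
                          torch.arange(4.0))
loader = DataLoader(dataset, batch_size=2, collate_fn=custom_collate)
batch = next(iter(loader)).to("cpu")
print(type(batch).__name__, len(batch["query"]), batch["values"].shape)
```

The queries come through as an ordinary list of objects, while "values" and "targets" are stacked into tensors with a leading batch dimension.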