高效的PyTorch DataLoader collat​​e_fn函数可用于各种尺寸的输入

时间:2019-01-07 23:34:40

标签: python-3.x machine-learning pytorch mini-batch

我在为PyTorch collate_fn类编写自定义DataLoader函数时遇到麻烦。我需要自定义功能,因为我的输入具有不同的尺寸。

我目前正在尝试编写Stanford MURA paper的基准实现。数据集具有一组标记的研究。一项研究可能包含多个图像。我创建了一个自定义的Dataset类,该类使用torch.stack堆叠了这些多个图像。

然后将堆叠的张量作为模型的输入提供,并且对输出列表进行平均以获得单个输出。当DataLoader时,此实现与batch_size=1配合良好。但是,当我尝试将batch_size设置为8时,就像在原始论文中一样,DataLoader失败了,因为它使用torch.stack来堆叠批次和我批次中的输入具有可变的尺寸(因为每个研究可以包含多个图像)。

为了解决此问题,我尝试实现自定义的collate_fn函数。

def collate_fn(batch):
    imgs = [item['images'] for item in batch]
    targets = [item['label'] for item in batch]
    targets = torch.LongTensor(targets)
    return imgs, targets

然后在训练纪元循环中,我像这样循环遍历每个批次:

for image, label in zip(*batch):
    label = label.type(torch.FloatTensor)
    # wrap them in Variable
    image = Variable(image).cuda()  
    label = Variable(label).cuda()
    # forward
    output = model(image)
    output = torch.mean(output)
    loss = criterion(output, label, phase)

但是,这并没有给我任何时间上改进的计时,并且仍然需要与仅将批处理大小设置为1一样的时间。我还尝试将批处理大小设置为32,但这并没有改善计时要么。

我做错什么了吗? 有更好的方法吗?

0 个答案:

没有答案