使用PYTORCH制作个人数据加载器

时间:2018-08-06 07:36:49

标签: image-processing deep-learning computer-vision pytorch image-segmentation

我正在搜索创建具有特定格式的个人数据加载器以使用Pytorch库,有人知道我该怎么办?我已经遵循了Pytorch教程,但找不到答案!

我需要一个DataLoader,它可以产生以下格式的元组:         (Bx3xHxW FloatTensor x,BxHxW LongTensor y,BxN LongTensor y_cls),其中         x-一批输入图像,         y-一批groung真实seg地图,         y_cls-一维维度N的一维张量的批处理:N个类的总数,         如果图像i中存在类T,则y_cls [i,T] = 1,否则为0

我希望有人可以解决问题..::)谢谢!

1 个答案:

答案 0 :(得分:1)

您只需要拥有一个从torch.utils.data.Dataset派生的数据库,其中__getitem__(index)返回所需类型的元组(x, y, y_cls),pytorch会处理其他所有事情。

from torch.utils import data

class MyTupleDataset(data.Dataset):
  def __init__(self):
    super(MyTupleDataset, self).__init__()
    # init your dataset here...

  def __getitem__(index):
    x = torch.Tensor(3, H, W)  # batch dim is handled by the data loader
    y = torch.Tensor(H, W).to(torch.long)
    y_cls = torch.Tensor(N).to(torch.long)
    return x, y, y_cls

就是这样。为pytorch的{​​{3}}提供MyTupleDataset,您就完成了。