我正在搜索创建具有特定格式的个人数据加载器以使用Pytorch库,有人知道我该怎么办?我已经遵循了Pytorch教程,但找不到答案!
我需要一个DataLoader,它可以产生以下格式的元组: (Bx3xHxW FloatTensor x,BxHxW LongTensor y,BxN LongTensor y_cls),其中 x-一批输入图像, y-一批groung真实seg地图, y_cls-一维维度N的一维张量的批处理:N个类的总数, 如果图像i中存在类T,则y_cls [i,T] = 1,否则为0
我希望有人可以解决问题..::)谢谢!
答案 0 :(得分:1)
您只需要拥有一个从torch.utils.data.Dataset
派生的数据库,其中__getitem__(index)
返回所需类型的元组(x, y, y_cls)
,pytorch会处理其他所有事情。
from torch.utils import data
class MyTupleDataset(data.Dataset):
def __init__(self):
super(MyTupleDataset, self).__init__()
# init your dataset here...
def __getitem__(index):
x = torch.Tensor(3, H, W) # batch dim is handled by the data loader
y = torch.Tensor(H, W).to(torch.long)
y_cls = torch.Tensor(N).to(torch.long)
return x, y, y_cls
就是这样。为pytorch的{{3}}提供MyTupleDataset
,您就完成了。