在PyTorch中创建数据集类:具有不同长度的输入

时间:2019-05-21 08:06:52

标签: pytorch

我的PyTorch模型包含两个CNN,然后将其输出合并并通过一系列完全连接的层。两个CNN的输入是矩阵:问题是第一个CNN的形状为128x100,而第二个CNN的形状为128x1000。我现在正在尝试创建一个Dataset类来生成加载程序。目前,我写了以下内容:

class Data(Dataset):

    def __init__(self, dataP, targetP, dataC, targetC, transform=None):
        self.dataP = [torch.from_numpy(X).int() for X in dataP]
        self.targetP = [torch.from_numpy(y).float() for y in targetP]

        self.dataC = [torch.from_numpy(X).int() for X in dataC]
        self.targetC = [torch.from_numpy(y).float() for y in targetC]

        self.transform = transform

    def __getitem__(self, index):
        Xp = self.dataP[index]
        yp = self.targetP[index]

        Xc = self.dataC[index]
        yc = self.targetC[index]

        if self.transform:
            Xp = self.transform(Xp)
            Xc = self.transform(Xc)

        return Xp, yp, Xc, yc

    def __len__(self):
        return len(self.dataP)

虽然代码似乎可以正常运行,但是我可以肯定有问题,因为在__len__方法中,我返回了其中一个输入的长度。可以照顾到不同大小的输入吗?

0 个答案:

没有答案