PyTorch数据集/数据加载器批处理

时间:2020-06-19 00:46:37

标签: pytorch

对于在时间序列数据上实现PyTorch数据管道的“最佳实践”,我有些困惑。

我有一个使用自定义DataLoader读取的HD5文件。似乎我应该将数据样本作为一个(特征,目标)元组返回,每个元组的形状为(L,C),其中L是seq_len,C是通道数-即,不要在数据加载器中进行批处理,只是作为桌子返回。

PyTorch模块似乎需要批量变暗,即Conv1D期望为(N,C,L)。

我的印象是DataLoader类将在批处理维度之前,但事实并非如此,我的数据形状为(N,L)。

dataset = HD5Dataset(args.dataset)

dataloader = DataLoader(dataset,
                        batch_size=N,
                        shuffle=True,
                        pin_memory=is_cuda,
                        num_workers=num_workers)

for i, (x, y) in enumerate(train_dataloader):
    ...

在上面的代码中,x的形状是(N,C)而不是(1,N,C),这导致下面的代码(来自公共git repo)在第一行失败。

def forward(self, x):
    """expected input shape is (N, L, C)"""
    x = x.transpose(1, 2).contiguous() # input should have dimension (N, C, L)

文档指出启用自动批处理 它总是将新维度添加为批处理维度,这使我相信自动批处理已禁用但我不明白为什么?

2 个答案:

答案 0 :(得分:0)

我发现了一些似乎可行的方法,一种选择似乎是使用DataLoader的collate_fn,但更简单的选择是使用BatchSampler,即

dataset = HD5Dataset(args.dataset)
train, test = train_test_split(list(range(len(dataset))), test_size=.1)

train_dataloader = DataLoader(dataset,
                        pin_memory=is_cuda,
                        num_workers=num_workers,
                        sampler=BatchSampler(SequentialSampler(train),batch_size=len(train), drop_last=True)
                        )

test_dataloader = DataLoader(dataset,
                        pin_memory=is_cuda,
                        num_workers=num_workers,
                        sampler=BatchSampler(SequentialSampler(test),batch_size=len(test), drop_last=True)
                        )

for i, (x, y) in enumerate(train_dataloader):
    print (x,y)

这会将数据集dim(L,C)转换为单批(1,L,C)(效率不高)。

答案 1 :(得分:0)

如果您有一对张量 (x, y) 的数据集,其中每个 x 的形状为 (C,L),则:

N, C, L = 5, 3, 10
dataset = [(torch.randn(C,L), torch.ones(1)) for i in range(50)]
dataloader = data_utils.DataLoader(dataset, batch_size=N)

for i, (x,y) in enumerate(dataloader):
    print(x.shape)

将为 (N,C,L) 生产 (50/N)=10 批形状 x

torch.Size([5, 3, 10])
torch.Size([5, 3, 10])
torch.Size([5, 3, 10])
torch.Size([5, 3, 10])
torch.Size([5, 3, 10])
torch.Size([5, 3, 10])
torch.Size([5, 3, 10])
torch.Size([5, 3, 10])
torch.Size([5, 3, 10])
torch.Size([5, 3, 10])