单次从数据集中采样每幅图像N次

时间:2019-05-04 20:37:12

标签: python neural-network computer-vision pytorch distributed-computing

我目前正在研究学习表示形式(深层嵌入)的任务。我使用的数据集每个对象只有一个示例图像。我也使用增强。

在训练期间,每个批次必须在数据集中包含N个不同的单个图像增强版本(dataset[index]总是返回新的随机变换)。

是否有一些标准的解决方案或带有DataLoader的库可以用于torch.utils.data.distributed.DistributedSampler? 如果没有,从torch.utils.data.DataLoader继承(并调用super().__init__(...))的任何DataLoader都可以在分布式培训中工作吗?

1 个答案:

答案 0 :(得分:0)

据我所知,这不是一种标准的处理方式-即使每个对象只有一个样本,每个批次仍然会从不同的对象中采样不同的图像,并且在不同的时期内采样的图像将是改变了。

也就是说,如果您真的想做自己想做的事情,为什么不简单地编写数据集的包装?

class Wrapper(Dataset):
    N = 16
    def __getitem__(self, index):
        sample = [ super().__getitem__(index) for _ in N ]
        sample = torch.stack(sample, dim=0)
        return sample

那么每个批次将是BxNxCxHxW,其中B是批次大小,N是您的重复次数。从数据加载器中获取批次后,您可以重新调整其形状。