自定义数据集,数据加载器,采样器,还是其他?

时间:2020-05-18 06:23:34

标签: python pytorch dataloader

我正在从事一个需要在非常大的图像数据集上训练PyTorch框架NN的项目。这些图像中的某些图像与问题完全不相关,但是这些不相关的图像并未这样标记。但是,有一些指标可以用来计算它们是否不相关(例如,将所有像素值相加可以使我很好地了解哪些是相关的图像,哪些不是)。我理想地希望做的是拥有一个可以装入Dataset类的Dataloader,并且仅使用相关图像创建批处理。 Dataset类只会知道图像列表及其标签,而Dataloader会解释与之成批处理的图像是否相关,然后仅与相关图像进行批处理。

要将其应用于示例,可以说我有一个黑白图像数据集。白色图像无关紧要,但没有这样标记。我希望能够从文件位置加载批次,并使这些批次仅包含黑色图像。我可以在某个时候通过对所有像素求和并找到等于0的方式进行过滤。

我想知道的是,自定义数据集,数据加载器或采样器是否可以为我解决此任务?我已经编写了一个自定义数据集,用于存储所有已保存图像的目录以及该目录中所有图像的列表,并且可以在 getitem 函数中返回带有标签的图像。我还应该添加一些内容以过滤掉某些图像吗?还是应该在自定义Dataloader或Sampler中应用该过滤器?

谢谢!

2 个答案:

答案 0 :(得分:0)

我假设您的图像数据集属于两个类别(0或1),但未标记。正如@PranayModukuru所提到的,您可以在巡回自定义数据集类的 getitem 函数中使用某种度量(例如,汇总图像的所有像素强度值)来确定相似性。

但是,在训练模型时确定 getitem 函数的相似性会使训练过程非常缓慢。因此,我建议您在开始训练之前先近似相似度(不在 getitem 函数中)。此外,如果您的图像数据集由复杂图像(不是黑白图像)组成,则最好使用预先训练的深度学习模型(例如resnet或自动编码器)来减少维数,然后再应用聚类方法(例如聚类聚类)来标记图像。

在第二种方法中,您只需要标记图像一次即可,如果在训练时在图像上应用增强,则无需在 getitem 中重新确定相似性(标签) >功能。另一方面,在第一种方法中,您需要每次在 getitem 函数中(对图像应用转换之后)确定相似性(标签),这是多余,不必要和耗时的。

希望这会有所帮助。

答案 1 :(得分:0)

听起来您的目标是从训练中完全删除不相关的图像。

处理此问题的最佳方法是预先弄清所有相关图像的文件名,并将其文件名保存到csv或其他内容中。然后仅将良好的文件名传递给数据集。

原因是您将在训练期间多次浏览数据集。这意味着您将一遍又一遍地加载,分析和丢弃不相关的图像,这浪费了计算。

最好预先进行这种预处理/过滤。