使用Pytorch数据加载器加载特定样品的简单方法

时间:2019-02-19 18:51:34

标签: python machine-learning deep-learning pytorch

我目前正在使用相对稀疏的标签(大约1%的标签数据中的体素对应于目标类别)来训练用于二进制分类的3D CNN。

为了在培训过程中进行基本的健全性检查(例如,网络是否完全学习?),向网络展示一小部分经过精心挑选的培训示例子集会很方便,这些子集的目标类别标签要高于平均水平

根据Pytorch文档的建议,我实现了自己的dataset类(继承自torch.utils.data.Dataset),该类通过其__get_item__方法向torch.utils.data.DataLoader提供了培训示例。 / p>

在我发现的pytorch tutorials中,DataLoader用作迭代器来生成训练循环,如下所示:

for i, data in enumerate(self.dataloader):

    # Get training data
    inputs, labels = data

    # Train the network
    # [...]

我现在想知道的是,是否存在一种简单的方法来加载单个或几个特定的​​训练示例(使用Dataset的{​​{ 1}}方法)。但是,__get_item__没有DataLoader方法,并且反复调用__get_item__直到达到所需的索引似乎不太好。

显然,解决此问题的一种可能方法是定义从抽象__next__继承的自定义samplerbatch_sampler。但这似乎很容易检索一些特定的样本。

我想我在这里忽略了一些非常简单明显的东西。任何建议表示赞赏!

2 个答案:

答案 0 :(得分:1)

以防万一有类似问题的人在某个时候遇到此问题:

我最终使用的快捷方法是直接访问与训练相关的dataloader属性,从而绕过训练循环中的dataset。假设我们想通过反复展示一个带有线性索引sample_idx(由数据集类定义)的精选的训练示例,来快速检查我们的网络是否完全可以学习。

然后可以执行以下操作:

for i, _ in enumerate(self.dataloader):

    # Get training data
    # inputs, labels = data

    inputs, labels = self.dataloader.dataset[sample_idx]
    inputs = inputs.unsqueeze(0)
    labels = labels.unsqueeze(0)

    # Train the network
    # [...]

答案 1 :(得分:0)

如果已定义

2

然后您可以执行类似的操作

train_set = torchvision.datasets.CIFAR10(root='~/datasets/', train=True, download=True, transform=(transform['train'])) ,其中train_set.data[index]是您想要的特定示例的index

现在,您可以使用包含这些特定示例的新数据集来重新定义index类,然后就可以使用了。