我目前正在使用相对稀疏的标签(大约1%的标签数据中的体素对应于目标类别)来训练用于二进制分类的3D CNN。
为了在培训过程中进行基本的健全性检查(例如,网络是否完全学习?),向网络展示一小部分经过精心挑选的培训示例子集会很方便,这些子集的目标类别标签要高于平均水平
根据Pytorch文档的建议,我实现了自己的dataset
类(继承自torch.utils.data.Dataset
),该类通过其__get_item__
方法向torch.utils.data.DataLoader
提供了培训示例。 / p>
在我发现的pytorch tutorials中,DataLoader
用作迭代器来生成训练循环,如下所示:
for i, data in enumerate(self.dataloader):
# Get training data
inputs, labels = data
# Train the network
# [...]
我现在想知道的是,是否存在一种简单的方法来加载单个或几个特定的训练示例(使用Dataset
的{{ 1}}方法)。但是,__get_item__
没有DataLoader
方法,并且反复调用__get_item__
直到达到所需的索引似乎不太好。
显然,解决此问题的一种可能方法是定义从抽象__next__
继承的自定义sampler
或batch_sampler
。但这似乎很容易检索一些特定的样本。
我想我在这里忽略了一些非常简单明显的东西。任何建议表示赞赏!
答案 0 :(得分:1)
以防万一有类似问题的人在某个时候遇到此问题:
我最终使用的快捷方法是直接访问与训练相关的dataloader
属性,从而绕过训练循环中的dataset
。假设我们想通过反复展示一个带有线性索引sample_idx
(由数据集类定义)的精选的训练示例,来快速检查我们的网络是否完全可以学习。
然后可以执行以下操作:
for i, _ in enumerate(self.dataloader):
# Get training data
# inputs, labels = data
inputs, labels = self.dataloader.dataset[sample_idx]
inputs = inputs.unsqueeze(0)
labels = labels.unsqueeze(0)
# Train the network
# [...]
答案 1 :(得分:0)
如果已定义
2
然后您可以执行类似的操作
train_set = torchvision.datasets.CIFAR10(root='~/datasets/', train=True,
download=True, transform=(transform['train']))
,其中train_set.data[index]
是您想要的特定示例的index
。
现在,您可以使用包含这些特定示例的新数据集来重新定义index
类,然后就可以使用了。