如何从DataLoader获取样本的文件名?

时间:2019-06-21 07:48:32

标签: python machine-learning pytorch torchvision

我需要编写一个文件,其中包含我训练的卷积神经网络的数据测试结果。该数据包括语音数据收集。文件格式需要为“文件名,预测”,但是我很难提取文件名。我这样加载数据:

import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

TEST_DATA_PATH = ...

trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

test_dataset = torchvision.datasets.MNIST(
    root=TEST_DATA_PATH,
    train=False,
    transform=trans,
    download=True
)

test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

并且我正在尝试按以下方式写入文件:

f = open("test_y", "w")
with torch.no_grad():
    for i, (images, labels) in enumerate(test_loader, 0):
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        file = os.listdir(TEST_DATA_PATH + "/all")[i]
        format = file + ", " + str(predicted.item()) + '\n'
        f.write(format)
f.close()

os.listdir(TESTH_DATA_PATH + "/all")[i]的问题在于它与test_loader的加载文件顺序不同步。我该怎么办?

2 个答案:

答案 0 :(得分:0)

好吧,这取决于您的Dataset的实现方式。例如,在torchvision.datasets.MNIST(...)情况下,您不能仅仅因为没有单个样本的文件名(MNIST样本为loaded in a different way)而检索文件名。

由于您没有展示自己的Dataset实现,因此我将告诉您如何使用torchvision.datasets.ImageFolder(...)(或任何torchvision.datasets.DatasetFolder(...))来实现:

f = open("test_y", "w")
with torch.no_grad():
    for i, (images, labels) in enumerate(test_loader, 0):
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        sample_fname, _ = test_loader.dataset.samples[i]
        f.write("{}, {}\n".format(sample_fname, predicted.item()))
f.close()

您可以看到在__getitem__(self, index),特别是here期间检索了文件的路径。

如果您实现了自己的Dataset(也许想支持shufflebatch_size > 1),那么我将在sample_fname上返回__getitem__(...)打电话并做这样的事情:

for i, (images, labels, sample_fname) in enumerate(test_loader, 0):
    # [...]

这样,您就不必关心shuffle。并且如果batch_size大于1,则需要更改循环的内容,以获取更通用的内容,例如:

f = open("test_y", "w")
for i, (images, labels, samples_fname) in enumerate(test_loader, 0):
    outputs = model(images)
    pred = torch.max(outputs, 1)[1]
    f.write("\n".join([
        ", ".join(x)
        for x in zip(map(str, pred.cpu().tolist()), samples_fname)
    ]) + "\n")
f.close()

答案 1 :(得分:0)

通常情况下,DataLoader可以从其中的数据集中为您提供批次。

在提到单标签/多标签分类问题时,提到了@ @Barriel,DataLoader没有图像文件名,只有表示图像的张量以及类/标签。

但是,DataLoader构造函数在加载对象时会占用很小的空间(如果需要,可以与数据集一起包装目标/标签和文件名)

这样,DataLoader可能会以某种方式满足您的需求。