如何从.pt文件创建Pytorch数据集?

时间:2019-04-15 16:16:06

标签: python computer-vision pytorch mnist dcgan

我已经将保存为.pt文件的MNIST图像转换为Google驱动器中的文件夹。我在Colab中编写我的Pytorch代码。

我想使用这些文件,并创建一个将这些图像存储为张量的数据集。我该怎么办?

在训练过程中转换图像花费了太长时间。因此,将它们转换并将它们全部保存为.pt文件。我只想将它们作为数据集重新加载并在模型中使用。

1 个答案:

答案 0 :(得分:1)

您遵循的保存图像的方法确实是一个好主意。在这种情况下,您只需编写自己的Dataset类即可加载图像。

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import RandomSampler

class ReaderDataset(Dataset):
    def __init__(self, filename):
        # load the images from file

    def __len__(self):
        # return total dataset size

    def __getitem__(self, index):
        # write your code to return each batch element

然后您可以按照以下步骤创建Dataloader。

train_dataset = ReaderDataset(filepath)
train_sampler = RandomSampler(train_dataset)
train_loader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    sampler=train_sampler,
    num_workers=args.data_workers,
    collate_fn=batchify,
    pin_memory=args.cuda,
    drop_last=args.parallel
)
# args is a dictionary containing parameters
# batchify is a custom function that prepares each mini-batch