如何将图像加载到Pytorch DataLoader?

时间:2018-04-26 21:46:07

标签: python pytorch

数据加载和处理的pytorch教程是一个非常具体的例子,有人可以帮我看一下这个函数应该是什么样的更通用的简单图像加载吗?

教程:http://pytorch.org/tutorials/beginner/data_loading_tutorial.html

我的数据:

我在以下文件夹结构中将MINST数据集作为jpg。 (我知道我可以使用数据集类,但这纯粹是为了看看如何在没有csv或复杂功能的情况下将简单图像加载到pytorch中。)

文件夹名称是标签,图像是灰度级的28x28 png,不需要转换。

data
    train
        0
            3.png
            5.png
            13.png
            23.png
            ...
        1
            3.png
            10.png
            11.png
            ...
        2
            4.png
            13.png
            ...
        3
            8.png
            ...
        4
            ...
        5
            ...
        6
            ...
        7
            ...
        8
            ...
        9
            ...

2 个答案:

答案 0 :(得分:10)

这就是我为pytorch 4.1做的事情

def load_dataset():
    data_path = 'data/train/'
    train_dataset = torchvision.datasets.ImageFolder(
        root=data_path,
        transform=torchvision.transforms.ToTensor()
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=64,
        num_workers=0,
        shuffle=True
    )
    return train_loader

for batch_idx, (data, target) in enumerate(load_dataset()):
    #train network

答案 1 :(得分:7)

如果您正在使用mnist,则通过torchvision在pytorch中已有预设 你可以做到

import torch
import torchvision
import torchvision.transforms as transforms
import pandas as pd

transform = transforms.Compose(
[transforms.ToTensor(),
 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

mnistTrainSet = torchvision.datasets.MNIST(root='./data', train=True,
                                    download=True, transform=transform)
mnistTrainLoader = torch.utils.data.DataLoader(mnistTrainSet, batch_size=16,
                                      shuffle=True, num_workers=2)

如果要推广到图像目录(与上面相同的导入),可以执行

class mnistmTrainingDataset(torch.utils.data.Dataset):

    def __init__(self,text_file,root_dir,transform=transformMnistm):
        """
        Args:
            text_file(string): path to text file
            root_dir(string): directory with all train images
        """
        self.name_frame = pd.read_csv(text_file,sep=" ",usecols=range(1))
        self.label_frame = pd.read_csv(text_file,sep=" ",usecols=range(1,2))
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.name_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.name_frame.iloc[idx, 0])
        image = Image.open(img_name)
        image = self.transform(image)
        labels = self.label_frame.iloc[idx, 0]
        #labels = labels.reshape(-1, 2)
        sample = {'image': image, 'labels': labels}

        return sample


mnistmTrainSet = mnistmTrainingDataset(text_file ='Downloads/mnist_m/mnist_m_train_labels.txt',
                                   root_dir = 'Downloads/mnist_m/mnist_m_train')

mnistmTrainLoader = torch.utils.data.DataLoader(mnistmTrainSet,batch_size=16,shuffle=True, num_workers=2)

然后你可以迭代它:

for i_batch,sample_batched in enumerate(mnistmTrainLoader,0):
    print("training sample for mnist-m")
    print(i_batch,sample_batched['image'],sample_batched['labels'])

有很多方法可以为图像数据集加载推广pytorch,我所知道的方法是子类化torch.utils.data.dataset