如何使用pytorch加载数据以及如何进行数据扩充

时间:2020-03-05 13:48:02

标签: deep-learning pytorch

我是Pytorch的新手,正在执行图像分类问题,但是我不了解如何从加载目录加载图像,请帮助我如何加载数据图像数据以及如何扩充图像。

这里的数据如下:

train=pd.read_csv('dataset/train.csv')
test=pd.read_csv('dataset/test.csv') 
train.head()
Image   Class
0   image7042.jpg   Food
1   image3327.jpg   misc
2   image10335.jpg  Attire
3   image8019.jpg   Food
4   image2128.jpg   Attire

这是我的图片文件夹:

file_path='dataset/Train Images'

1 个答案:

答案 0 :(得分:0)

您可以为此使用torchvision。假设您已将所有训练/测试图像分为两个文件夹,分别称为traintest,下面是一些有关如何加载和遍历图像的示例代码:

import torchvision
from torchvision import datasets, transforms

def load_dataset(data_path):
    dataset = torchvision.datasets.ImageFolder(
        root=data_path,
        transform=transforms.Compose([torchvision.transforms.ToTensor()])
    )
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        num_workers=0,
        shuffle=True
    )
    return data_loader

train_loader = load_dataset(f'{base_dir}/train')
test_loader = load_dataset(f'{base_dir}/test')

for batch_idx, (data, _) in enumerate(train_loader):
   # Train model

...

for batch_idx, (data, _) in enumerate(test_loader):
   # Evaluate model

如果要分批训练模型,可以增加batch_size,在transform参数中添加变形器以增强图像,以及其他许多事情。

查看文档:{​​{3}}