数据加载和处理的pytorch教程是一个非常具体的例子,有人可以帮我看一下这个函数应该是什么样的更通用的简单图像加载吗?
教程:http://pytorch.org/tutorials/beginner/data_loading_tutorial.html
我的数据:
我在以下文件夹结构中将MINST数据集作为jpg。 (我知道我可以使用数据集类,但这纯粹是为了看看如何在没有csv或复杂功能的情况下将简单图像加载到pytorch中。)
文件夹名称是标签,图像是灰度级的28x28 png,不需要转换。
data
train
0
3.png
5.png
13.png
23.png
...
1
3.png
10.png
11.png
...
2
4.png
13.png
...
3
8.png
...
4
...
5
...
6
...
7
...
8
...
9
...
答案 0 :(得分:10)
这就是我为pytorch 4.1做的事情
def load_dataset():
data_path = 'data/train/'
train_dataset = torchvision.datasets.ImageFolder(
root=data_path,
transform=torchvision.transforms.ToTensor()
)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=64,
num_workers=0,
shuffle=True
)
return train_loader
for batch_idx, (data, target) in enumerate(load_dataset()):
#train network
答案 1 :(得分:7)
如果您正在使用mnist,则通过torchvision在pytorch中已有预设 你可以做到
import torch
import torchvision
import torchvision.transforms as transforms
import pandas as pd
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
mnistTrainSet = torchvision.datasets.MNIST(root='./data', train=True,
download=True, transform=transform)
mnistTrainLoader = torch.utils.data.DataLoader(mnistTrainSet, batch_size=16,
shuffle=True, num_workers=2)
如果要推广到图像目录(与上面相同的导入),可以执行
class mnistmTrainingDataset(torch.utils.data.Dataset):
def __init__(self,text_file,root_dir,transform=transformMnistm):
"""
Args:
text_file(string): path to text file
root_dir(string): directory with all train images
"""
self.name_frame = pd.read_csv(text_file,sep=" ",usecols=range(1))
self.label_frame = pd.read_csv(text_file,sep=" ",usecols=range(1,2))
self.root_dir = root_dir
self.transform = transform
def __len__(self):
return len(self.name_frame)
def __getitem__(self, idx):
img_name = os.path.join(self.root_dir, self.name_frame.iloc[idx, 0])
image = Image.open(img_name)
image = self.transform(image)
labels = self.label_frame.iloc[idx, 0]
#labels = labels.reshape(-1, 2)
sample = {'image': image, 'labels': labels}
return sample
mnistmTrainSet = mnistmTrainingDataset(text_file ='Downloads/mnist_m/mnist_m_train_labels.txt',
root_dir = 'Downloads/mnist_m/mnist_m_train')
mnistmTrainLoader = torch.utils.data.DataLoader(mnistmTrainSet,batch_size=16,shuffle=True, num_workers=2)
然后你可以迭代它:
for i_batch,sample_batched in enumerate(mnistmTrainLoader,0):
print("training sample for mnist-m")
print(i_batch,sample_batched['image'],sample_batched['labels'])
有很多方法可以为图像数据集加载推广pytorch,我所知道的方法是子类化torch.utils.data.dataset