我正在Kaggle上进行仙人掌图像竞赛,我正在尝试为我的CNN使用PyTorch数据加载器。但是,我遇到了一个问题,无法为训练集设置标签。训练集图像位于文件夹中,标签位于csv文件中。这是我的代码。
train = torchvision.datasets.ImageFolder(root='../input/train',
transform=transform)
train.targets = torch.from_numpy(df['has_cactus'].values)
train_loader = torch.utils.data.DataLoader(train, batch_size=64, shuffle=True, num_workers=2)
for i, data in enumerate(train_loader, 0):
print(data[1])
此代码输出全零的批处理张量,这显然是不正确的,因为绝大多数标签(如果要查看数据帧)都是1。我认为将标签分配给“ train.targets”存在问题。如果在分配其他标签之前打印了“ train.targets”,它将返回全零的张量,这与我得到的错误结果一致。如何解决此问题?
答案 0 :(得分:1)
我通常如下继承内置的DataSet类:
from torch.utils.data import DataLoader
class DataSet:
def __init__(self, root):
"""Init function should not do any heavy lifting, but
must initialize how many items are available in this data set.
"""
self.ROOT = root
self.images = read_images(root + "/images")
self.labels = read_labels(root + "/labels")
def __len__(self):
"""return number of points in our dataset"""
return len(self.images)
def __getitem__(self, idx):
""" Here we have to return the item requested by `idx`
The PyTorch DataLoader class will use this method to make an iterable for
our training or validation loop.
"""
img = images[idx]
label = labels[idx]
return img, label
现在,您可以创建此类的实例,
ds = Dataset('../input/train')
现在,您可以实例化DataLoader:
dl = DataLoader(ds, batch_size=TRAIN_BATCH_SIZE, shuffle=False, num_workers=4, drop_last=True)
这将创建一批数据,您可以按以下方式访问它们:
for image, label in dl:
print(label)
答案 1 :(得分:0)
您可以通过继承@Sai Krishnan提到的内置Dataset
类来创建自定义数据集加载器。
from collections import Counter
import matplotlib.pyplot as plt
import numpy as np
import os
import argparse
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from PIL import Image
VOC_CLASSES = ('background', # always index 0
'aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair',
'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor')
NUM_CLASSES = len(VOC_CLASSES) + 1
class customDataset(Dataset):
"""Pascal VOC 2007 Dataset"""
def __init__(self, list_file, img_dir, mask_dir, transform=None):
# list of images to load in a .txt file
self.images = open(list_file, "rt").read().split("\n")[:-1]
self.transform = transform
# note that in the .txt file the image names are stored without the extension(.jpg or .png)
self.img_extension = ".jpg"
self.mask_extension = ".png"
self.image_root_dir = img_dir
self.mask_root_dir = mask_dir
# can comment the line below
self.counts = self.__compute_class_probability()
def __len__(self):
return len(self.images)
def __getitem__(self, index):
name = self.images[index]
image_path = os.path.join(self.image_root_dir, name + self.img_extension)
mask_path = os.path.join(self.mask_root_dir, name + self.mask_extension)
image = self.load_image(path=image_path)
gt_mask = self.load_mask(path=mask_path)
data = {
'image': torch.FloatTensor(image),
'mask' : torch.LongTensor(gt_mask)
}
return data
def __compute_class_probability(self):
counts = dict((i, 0) for i in range(NUM_CLASSES))
for name in self.images:
mask_path = os.path.join(self.mask_root_dir, name + self.mask_extension)
raw_image = Image.open(mask_path).resize((224, 224))
imx_t = np.array(raw_image).reshape(224*224)
imx_t[imx_t==255] = len(VOC_CLASSES)
for i in range(NUM_CLASSES):
counts[i] += np.sum(imx_t == i)
return counts
def get_class_probability(self):
values = np.array(list(self.counts.values()))
p_values = values/np.sum(values)
return torch.Tensor(p_values)
def load_image(self, path=None):
# can use any other library too like OpenCV as long as you are consistent with it
raw_image = Image.open(path)
raw_image = np.transpose(raw_image.resize((224, 224)), (2,1,0))
imx_t = np.array(raw_image, dtype=np.float32)/255.0
return imx_t
# can comment the below function if not needed
def load_mask(self, path=None):
raw_image = Image.open(path)
raw_image = raw_image.resize((224, 224))
imx_t = np.array(raw_image)
imx_t[imx_t==255] = len(VOC_CLASSES)
return imx_t
一旦类准备就绪,您可以创建它的实例并使用它。
data_root = os.path.join("VOCdevkit", "VOC2007")
list_file_path = os.path.join(data_root, "ImageSets", "Segmentation", "train.txt")
img_dir = os.path.join(data_root, "JPEGImages")
mask_dir = os.path.join(data_root, "SegmentationClass")
objects_dataset = customDataset(list_file=list_file_path,
img_dir=img_dir,
mask_dir=mask_dir)
sample = objects_dataset[k]
image, mask = sample['image'], sample['mask']
image.transpose_(0, 2)
fig = plt.figure()
a = fig.add_subplot(1,2,1)
plt.imshow(image)
a = fig.add_subplot(1,2,2)
plt.imshow(mask)
plt.show()
请确保正确插入文件路径。另外,您还必须在customDataset()
类中正确加载标签。
注意:此代码段只是自定义数据加载器的外观示例。您必须对其进行适当的更改才能使其适合您的情况。