I have a reference.csv file with three columns: Type, Class, and Path. Here are the first five example rows:
"Type","Class","Path"
"train","A","./path1/001.jpg"
"train","A","./path2/002.jpg"
"test","C","./path3/003.jpg"
"train","B","./path4/001.jpg"
"test","B","./path5/002.jpg"
...
In a more reader-friendly format:
|-------|-------|-----------------|
| Type  | Class | Path            |
|-------|-------|-----------------|
| train | A     | ./path1/001.jpg |
| train | A     | ./path2/002.jpg |
| test  | C     | ./path3/003.jpg |
| train | B     | ./path4/001.jpg |
| test  | B     | ./path5/002.jpg |
|-------|-------|-----------------|
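I want to create a dataset class (torch.utils.data.Dataset) that reads the images, so that I can use a DataLoader (torch.utils.data.DataLoader).
What is the correct way to create a custom dataset from this reference table?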
Answer 0 (score: 1)
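If we want to build a custom dataset that reads image locations from a CSV file, we can do something like the following. Note that this example assumes a headerless CSV whose columns are, in order, the image path, the label, and an operation flag, so your logic may differ.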
import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class CustomDatasetFromImages(Dataset):
    def __init__(self, csv_path):
        """
        Args:
            csv_path (string): path to the csv file
        """
        # Transform that converts a PIL image to a tensor
        self.to_tensor = transforms.ToTensor()
        # Read the csv file (header=None assumes the file has no header row)
        self.data_info = pd.read_csv(csv_path, header=None)
        # First column contains the image paths
        self.image_arr = np.asarray(self.data_info.iloc[:, 0])
        # Second column is the labels
        self.label_arr = np.asarray(self.data_info.iloc[:, 1])
        # Third column is an operation indicator
        self.operation_arr = np.asarray(self.data_info.iloc[:, 2])
        # Number of samples
        self.data_len = len(self.data_info.index)

    def __getitem__(self, index):
        # Get the image path from the first column
        single_image_name = self.image_arr[index]
        # Open the image
        img_as_img = Image.open(single_image_name)
        # Check whether there is an operation for this sample
        some_operation = self.operation_arr[index]
        if some_operation:
            # Do some operation on the image
            # ...
            # ...
            pass
        # Transform the image to a tensor
        img_as_tensor = self.to_tensor(img_as_img)
        # Get the label (class) of the image from the label column
        single_image_label = self.label_arr[index]
        return (img_as_tensor, single_image_label)

    def __len__(self):
        return self.data_len


if __name__ == "__main__":
    # Create the dataset
    custom_mnist_from_images = \
        CustomDatasetFromImages('../data/mnist_labels.csv')
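
For the reference.csv layout in the question (a header row followed by Type, Class, Path), the same idea could be adapted roughly as below. This is only a sketch under some assumptions: the names ReferenceCsvDataset and make_loader are made up for illustration, the paths in the file are taken to be valid relative to the working directory, class names are mapped to integer indices in sorted order, and when no transform is passed the image is converted to a float tensor by hand instead of using torchvision.

```python
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset


class ReferenceCsvDataset(Dataset):
    """Hypothetical dataset driven by reference.csv (columns: Type, Class, Path)."""

    def __init__(self, csv_path, split, classes=None, transform=None):
        # The first row of the file is the header, so pandas' default applies.
        frame = pd.read_csv(csv_path)
        # Derive the label order from the whole file so that the train and
        # test splits agree on which index belongs to which class.
        if classes is None:
            classes = sorted(frame["Class"].unique())
        self.classes = list(classes)
        self.class_to_idx = {name: i for i, name in enumerate(self.classes)}
        # Keep only the rows of the requested split ("train" or "test").
        rows = frame[frame["Type"] == split]
        self.paths = rows["Path"].tolist()
        self.labels = [self.class_to_idx[name] for name in rows["Class"]]
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        # Paths are opened exactly as written in the CSV (assumption).
        with Image.open(self.paths[index]) as img:
            img = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(img)
        else:
            # Fallback: HWC uint8 image -> CHW float32 tensor in [0, 1].
            array = np.asarray(img, dtype=np.float32) / 255.0
            image = torch.from_numpy(array).permute(2, 0, 1)
        return image, self.labels[index]


def make_loader(csv_path, split, batch_size=4, transform=None):
    """Hypothetical helper: wrap one split of the CSV in a DataLoader."""
    dataset = ReferenceCsvDataset(csv_path, split, transform=transform)
    # Shuffle only the training split.
    return DataLoader(dataset, batch_size=batch_size, shuffle=(split == "train"))
```

Something like make_loader("reference.csv", "train") would then yield batches of (images, labels). The default batching stacks the image tensors, so all images in a batch need the same size; if yours differ, passing a transform such as a torchvision Compose of Resize and ToTensor is one way to handle that.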