我想知道如何在PyTorch中将数据加载器用于我的自定义文件结构。我已经阅读了PyTorch文档,但其中的所有示例都假定每个类别有各自单独的文件夹。
我的文件夹结构由2个文件夹(称为训练和验证)组成,每个文件夹具有2个子文件夹(称为images和json_annotations)。“images”文件夹中的每个图像都包含多个对象(例如汽车、自行车、人等),每个对象都带有注释,并具有单独的JSON文件,遵循标准的COCO注释格式。我的意图是制作一个可以对视频进行实时分类的神经网络。
编辑1: 我已经按照FábioPerez的建议进行了编码。
class lDataSet(data.Dataset):
    """Dataset that pairs each image in a folder with its JSON annotation.

    Every entry of ``path_to_imgs`` is treated as an image id. For index
    ``idx`` the file name's extension is stripped, the image is read from
    ``<path_to_imgs>/<id>.jpg`` and the annotation is parsed from
    ``<path_to_json>/<id>.json``.

    Args:
        path_to_imgs: Directory containing the image files.
        path_to_json: Directory containing one JSON file per image, named
            after the image.
    """

    def __init__(self, path_to_imgs, path_to_json):
        self.path_to_imgs = path_to_imgs
        self.path_to_json = path_to_json
        # The directory listing is taken once; its order defines the indices.
        self.img_ids = os.listdir(path_to_imgs)

    def __getitem__(self, idx):
        """Return ``(image, annotation)`` for the sample at position ``idx``.

        NOTE(review): per the OpenCV documentation ``cv2.imread`` returns
        ``None`` instead of raising when the file cannot be read, so callers
        may want to check the image before using it.
        """
        img_id = os.path.splitext(self.img_ids[idx])[0]
        img = cv2.imread(os.path.join(self.path_to_imgs, img_id + ".jpg"))
        # Use a context manager so the annotation file is always closed.
        with open(os.path.join(self.path_to_json, img_id + ".json")) as json_file:
            load_json = json.load(json_file)
        # The whole parsed annotation is returned; extract the fields you
        # need (e.g. boxes or segmentations) here if only those are wanted.
        return img, load_json

    def __len__(self):
        """Return the number of entries found in the image directory."""
        # Must read the attribute set in __init__ (``img_ids``).
        return len(self.img_ids)
当我尝试
# Build the dataset from the training image and annotation folders.
l_data = lDataSet(
    path_to_imgs='/home/training/images',
    path_to_json='/home/training/json_annotations',
)
我可以通过l_data[i][0]得到图像,通过l_data[i][1]得到对应的JSON。现在我很困惑:如何把它用在PyTorch提供的微调示例中?在该示例中,数据集和数据加载器如下所示。
https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html
# One ImageFolder dataset per phase, each with that phase's transforms.
image_datasets = {
    phase: datasets.ImageFolder(
        os.path.join(data_dir, phase),
        data_transforms[phase],
    )
    for phase in ['train', 'val']
}
# One shuffling DataLoader per phase, wrapping the dataset built above.
dataloaders_dict = {
    phase: torch.utils.data.DataLoader(
        image_datasets[phase],
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
    )
    for phase in ['train', 'val']
}
答案 0 :(得分:1)
您应该可以使用data.Dataset来实现自己的数据集。您只需要实现__getitem__
__name__ = 'DEMO'
class GradeHandler(object):
    """Keep a list of subjects together with the grade recorded for each.

    The class-level constants hold the initial data. Each instance works on
    its own copies, so changing one handler never alters the defaults used
    by other instances.

    Attributes set on every instance:
        Emner: list of subject names.
        FagKoder: list of ``[name, code]`` pairs (copied, not used further
            by this class).
        Karakterer: list of ``[subject, grade]`` pairs used as the initial
            grades.
        grade_dict: mapping of subject name to grade ('' when no grade is
            known).
    """

    EMNER = ["INFO100", "INFO104", "INFO110", "INFO150", "INFO125"]
    FAGKODER = [["Informasjonsvitenskap", "INF"], ["Kognitiv vitenskap", "KVT"]]
    KARAKTERER = [["INFO100", "C"], ["INFO104", "B"], ["INFO110", "E"]]

    def __init__(self):
        # Copy the class-level lists: aliasing them would let append() on an
        # instance mutate the shared defaults.
        self.Emner = list(self.EMNER)
        self.FagKoder = [list(pair) for pair in self.FAGKODER]
        self.Karakterer = [list(pair) for pair in self.KARAKTERER]
        self.__create_grade_dict()

    def __initial_grades(self):
        """Return the ``Karakterer`` pairs as a ``{subject: grade}`` dict."""
        # A plain comprehension also works when Karakterer is empty, unlike
        # unpacking the result of zip(*...).
        return {subject: grade for subject, grade in self.Karakterer}

    def remove_subject(self, subject_name):
        """
        Remove a subject from this instance's subject list.

        The subject's entry in the grade dictionary is removed as well.
        Unknown subjects are ignored.
        """
        # Filter the instance list (not the class constant) so subjects
        # added earlier are kept.
        self.Emner = [name for name in self.Emner if name != subject_name]
        self.grade_dict.pop(subject_name, None)

    def add_subject(self, subject_name):
        """
        Append a subject to this instance's subject list.

        Does nothing when the subject is already listed. A newly added
        subject keeps a grade already present in the grade dictionary;
        otherwise it gets its initial grade from ``Karakterer`` or ''.
        """
        if subject_name in self.Emner:
            return
        self.Emner.append(subject_name)
        # Only this one entry is touched, so grades changed through
        # update_grade() for other subjects are not reset.
        if subject_name not in self.grade_dict:
            self.grade_dict[subject_name] = self.__initial_grades().get(subject_name, '')

    def __create_grade_dict(self, grade_dict=None):
        """
        Fill ``self.grade_dict`` with one entry per subject in ``Emner``.

        Each subject gets its grade from ``Karakterer``, or '' when it has
        none. With ``grade_dict=None`` (the default) the dictionary is
        started from scratch; any other value keeps the existing entries
        and only overwrites those for the listed subjects.
        """
        if grade_dict is None or not hasattr(self, 'grade_dict'):
            self.grade_dict = dict()
        initial = self.__initial_grades()
        for subject in self.Emner:
            self.grade_dict[subject] = initial.get(subject, '')

    def update_grade(self, subject_name, grade='A'):
        """
        Update a grade in the grade dictionary.

        A subject that is not yet in the dictionary is added to it (the
        subject list ``Emner`` is left unchanged).
        """
        self.grade_dict[subject_name] = grade

    def print_grades(self, subject_name=None):
        """
        Print the grades as ``<subject> <grade>`` lines.

        With no argument every entry is printed; with a subject name only
        that entry is printed, and nothing is printed if it is unknown.
        """
        if subject_name is None:
            for k, v in self.grade_dict.items():
                print('{} {}'.format(k, v))
        elif subject_name in self.grade_dict:
            print('{} {}'.format(subject_name, self.grade_dict[subject_name]))
if __name__ == 'DEMO':
    # Start with a fresh handler and show the grades it was created with.
    gh = GradeHandler()
    gh.print_grades()

    # Register an extra subject; it is listed without a grade.
    gh.add_subject('GE0124')
    gh.print_grades()

    # Give the new subject its first grade.
    gh.update_grade('GE0124', 'B+')
    gh.print_grades()

    # Overwrite that grade with a new value.
    gh.update_grade('GE0124', 'A-')
    gh.print_grades()

    # Drop the subject again (its grade disappears with it).
    gh.remove_subject('GE0124')
    gh.print_grades()
和__len__方法。
根据您的情况,您可以遍历图像文件夹中的所有图像(然后可以将图像ID存储在Dataset中的列表中)。然后,您使用传递给__getitem__的索引来获取相应的图像ID。使用此图像ID,您可以读取相应的JSON文件并返回所需的目标数据。
类似这样的东西:
class YourDataLoader(data.Dataset):
    """Skeleton dataset returning an image and its bounding boxes per index.

    ``iterate_through_images``, ``load_image`` and ``load_bboxes`` are
    placeholders that are not defined here; supply your own implementations.

    Args:
        path_to_imgs: Directory containing the images.
        path_to_json: Directory containing the JSON annotations.
    """

    def __init__(self, path_to_imgs, path_to_json):
        # Store the paths under the attribute names that __getitem__ reads.
        self.path_to_images = path_to_imgs
        self.path_to_json = path_to_json
        self.image_ids = iterate_through_images(path_to_imgs)

    def __getitem__(self, idx):
        """Return ``(img, bboxes)`` for the image id at position ``idx``."""
        img_id = self.image_ids[idx]
        img = load_image(os.path.join(self.path_to_images, img_id))
        bboxes = load_bboxes(os.path.join(self.path_to_json, img_id))
        return img, bboxes

    def __len__(self):
        """Return the number of image ids collected in ``__init__``."""
        return len(self.image_ids)
在iterate_through_images中,您将获得目录中图像的所有ID(例如文件名)。
在load_bboxes中,您将读取JSON并获取所需的信息。
如果您要引用,我有一个JSON加载程序实现here。