我正在尝试使用PyTorch训练网络。数据集由许多jpg图像以及保存在xml文件中的标注(标签和边界框)组成。我编写了一个用于提取标注的基本数据加载器。
class dataLoader(Dataset):
    """Dataset pairing each image with the objects from its XML annotation.

    Indexing returns ``(image, objects)``: the opened (and optionally
    transformed) image, and the list produced by ``get_targets`` without its
    last element, which is used as the image name.
    """

    def __init__(self, path, root, img_trfm=None, resize=None):
        """Collect the annotation file names and store the image settings.

        Args:
            path: Passed to ``get_file_name`` to obtain the annotation file
                names (``'.xml'`` is appended when an item is loaded).
            root: Prefix prepended to the image name when opening the
                ``.jpg`` file.
            img_trfm: Optional callable applied to the opened image. When
                ``None`` the image is returned untransformed.
            resize: Optional edge length for a square ``transforms.Resize``.
                The resize transform is only stored on the instance; it is
                not applied by ``__getitem__``.
        """
        self.filename = get_file_name(path)
        self.data = self.filename
        self.root = root
        self.resize = resize
        # Only build the resize transform when a size was actually given;
        # a Resize built from (None, None) is not a usable transform.
        self.img_resize = (
            transforms.Resize((resize, resize)) if resize is not None else None
        )
        self.img_trfm = img_trfm
        self.data_len = len(self.data)

    def __getitem__(self, index):
        """Load the image and annotated objects for the sample at ``index``.

        Returns:
            A tuple ``(image, objects)`` where each processed object dict has
            its ``'bndbox'`` converted to a tensor and an ``'id'`` entry added
            from ``get_class_id``.
        """
        filename = self.filename[index]
        objects, num_objs = get_targets(filename + '.xml')
        # The last element of `objects` is used as the image name; everything
        # before it is a per-object dictionary.
        img_name = objects[-1]
        img = Image.open(self.root + img_name + '.jpg')
        # The transform is optional (defaults to None), so only call it when
        # one was supplied instead of failing on a None callable.
        if self.img_trfm is not None:
            img = self.img_trfm(img)
        # Convert each bounding box to a tensor and map the class label to
        # its integer id.
        # NOTE(review): presumably num_objs also counts the trailing
        # image-name entry -- confirm against get_targets.
        for i in range(num_objs - 1):
            obj = objects[i]
            obj['bndbox'] = torch.Tensor(obj['bndbox'])
            obj['id'] = get_class_id(obj['label'])
        return img, objects[:-1]

    def __len__(self):
        """Return the number of annotation files in the dataset."""
        return self.data_len
输出是单张图像,以及由图像中每个对象对应的字典组成的列表:
# Pull a single batch from the loader and inspect its contents.
dataiter = iter(trainloader)
# Use the built-in next(): the iterator's .next() method is the Python 2
# spelling and is not available on current DataLoader iterators.
img, objects = next(dataiter)
print(img.shape)
print(objects)
torch.Size([1, 3, 512, 640])
[{'label': ['person'],
'bndbox': tensor([[444., 220., 27., 65.]]),
'id': tensor([0])},
{'label': ['person'],
'bndbox': tensor([[468., 220., 26., 66.]]),
'id': tensor([0])},
{'label': ['person?'],
'bndbox': tensor([[415., 224., 20., 33.]]),
'id': tensor([3])}]
我还写了一个基本的训练函数,仅使用标签ID来训练网络:
# Train for a single epoch, taking one optimisation step per annotated
# object: every object's class id is used as the target for the whole image.
for epoch in range(1):
    running_loss = 0.0
    for data in trainloader:
        images, objects = data
        for k in objects:
            label = k['id']
            print(label)
            optimizer.zero_grad()
            # The forward pass is repeated per object because the weights
            # change after every optimizer step.
            outputs = net(images)
            loss = criterion(outputs, label)
            loss.backward()
            print(loss)
            optimizer.step()
            # Accumulate the scalar loss so running_loss reflects the epoch
            # total instead of staying at 0.0.
            running_loss += loss.item()
从打印出的损失值来看,它似乎可以正常运行:
tensor(0.46, grad_fn=<NllLossBackward>)
tensor(0.46, grad_fn=<NllLossBackward>)
tensor(2.13, grad_fn=<NllLossBackward>)
tensor(0.45, grad_fn=<NllLossBackward>)
tensor(0.44, grad_fn=<NllLossBackward>)
tensor(1.57, grad_fn=<NllLossBackward>)
tensor(0.43, grad_fn=<NllLossBackward>)
tensor(0.43, grad_fn=<NllLossBackward>)
tensor(2.30, grad_fn=<NllLossBackward>)
tensor(0.42, grad_fn=<NllLossBackward>)
tensor(0.42, grad_fn=<NllLossBackward>)
tensor(1.66, grad_fn=<NllLossBackward>)
tensor(0.42, grad_fn=<NllLossBackward>)
tensor(0.42, grad_fn=<NllLossBackward>)
tensor(2.12, grad_fn=<NllLossBackward>)
tensor(0.42, grad_fn=<NllLossBackward>)
tensor(0.42, grad_fn=<NllLossBackward>)
......
但是,在这些损失值之后,我不断收到以下错误消息:
File "/miniconda3/envs/env_PyTorch/lib/python3.7/xml/etree/ElementTree.py", line 598, in parse
self._root = parser._parse_whole(source)
File "<string>", line unknown
ParseError: no element found: line 1, column
任何建议将不胜感激。