I am trying to do object detection on a custom dataset in PyTorch.
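The tutorial here provides a snippet for fine-tuning a pretrained model to classify custom objects:
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import models
# device and train_model are defined earlier in the tutorial
model_ft = models.resnet18(pretrained=True)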
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)
model_ft = model_ft.to(device)
criterion = nn.CrossEntropyLoss()
# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)
I tried a similar approach for object detection with a Faster R-CNN model:
# load a model pre-trained on COCO
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()
for param in model.parameters():
    param.requires_grad = False
# replace the classifier with a new one, that has
# num_classes which is user-defined
num_classes = 1 # 1 class (person) + background
print(model)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
model = train_model(model, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25)
PyTorch raises the following error. Is this approach even correct in the first place?
Epoch 0/24
----------
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-69-527ca4db8e5d> in <module>()
----> 1 model = train_model(model, criterion, optimizer_ft, exp_lr_scheduler,num_epochs=25)
/usr/local/lib/python3.6/dist-packages/torchvision/models/detection/generalized_rcnn.py in forward(self, images, targets)
43 """
44 if self.training and targets is None:
---> 45 raise ValueError("In training mode, targets should be passed")
46 original_image_sizes = [img.shape[-2:] for img in images]
47 images, targets = self.transform(images, targets)
ValueError: In training mode, targets should be passed
Can this example be adapted for custom object detection? https://www.learnopencv.com/faster-r-cnn-object-detection-with-pytorch/
Answer
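The error message says it all. To train the model you need to pass image, target pairs, where each target is a dictionary holding the bounding boxes, labels, and masks for that image (Faster R-CNN only needs the boxes and labels; masks are used by Mask R-CNN).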
For more details and a complete walkthrough, see https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html