我想创建一个用于对象定位的CNN网络(假设只有一个对象)。为此,我使用了一些常规图层,最后我想获得到原点最近和最远的角。我也在使用自定义损失函数,该函数是(100-工会上的相互作用,以%为单位)。损失没有收敛。可能是什么问题? Wheather反向传播是否可以与此网络一起使用或解决其他问题?下面是代码: 出于幻觉目的,请参见。
网络:
class convnet(nn.Module):
def __init__(self):
super(convnet, self).__init__()
self.conv1 = nn.Conv2d(1, 4, kernel_size=5)
self.pool1 = nn.MaxPool2d(kernel_size=3,stride=3)
self.conv2 = nn.Conv2d(4, 8, kernel_size=5)
self.pool2 = nn.MaxPool2d(kernel_size=3,stride=3)
self.conv3 = nn.Conv2d(8, 16, kernel_size=5)
self.pool3 = nn.MaxPool2d(kernel_size=3,stride=3)
self.fc1 = nn.Linear(5040, 1000)
self.fc2 = nn.Linear(1000, 84)
self.fc3 = nn.Linear(84, 4)
def forward(self, x):
x = F.relu(self.conv1(x))
x = self.pool1(x)
x = F.relu(self.conv2(x))
x = self.pool2(x)
x = F.relu(self.conv3(x))
x = self.pool3(x)
x = x.view(-1, 5040)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
a = self.sigmoid(self.fc3(x))
c = torch.zeros(a.shape[0], 2)
for idx, x in enumerate(a):
d1 = x[0] ** 2 + x[1] ** 2
d2 = x[2] ** 2 + x[3] ** 2
d3 = x[0] ** 2 + x[3] ** 2
d4 = x[2] ** 2 + x[1] ** 2
dmin = min(d1, d2, d3, d4)
if d1 == dmin:
c[idx] = torch.tensor([x[0], x[1]])
elif d2 == dmin:
c[idx] = torch.tensor([x[2], x[3]])
elif d3 == dmin:
c[idx] = torch.tensor([x[0], x[3]])
elif d4 == dmin:
c[idx] = torch.tensor([x[2], x[1]])
m = torch.tensor([[640, 480, 640, 480]]).type(torch.DoubleTensor).cuda()
return c*m
def sigmoid(self, z):
return 1/(1+torch.exp(-z))
丢失功能:
def iou(box_a, box_b):
A = box_a.size(0)
B = box_b.size(0)
max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2),
box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2),
box_b[:, :2].unsqueeze(0).expand(A, B, 2))
inter = torch.clamp((max_xy - min_xy), min=0)
inter =inter[:, :, 0] * inter[:, :, 1]
area_a = ((box_a[:, 2]-box_a[:, 0]) *
(box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter)
area_b = ((box_b[:, 2]-box_b[:, 0]) *
(box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter)
union = area_a + area_b - inter
return ((inter / union)*100/float(A*A)).sum()
def criterion(output, labels):
return (100-iou(output, labels))
您可以在此处查看完整的代码:link