试图解决多标签图像分类问题,即8类多标签问题。我的标签看起来像这样
[0. 1. 0. 0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 0. 0.]
[0. 1. 0. 1. 0. 0. 0. 0.]
代码看起来像这样
model = models.resnet50(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 8)
model = model.cuda()
# Decay LR by a factor of 0.1 every 7 epochs
# exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
loss_fn = torch.nn.BCELoss()
# def train(epoch):
for epoch in range(15):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
# import pdb;pdb.set_trace()
data, target = data.cuda(), target.float().cuda()
optimizer.zero_grad()
output = model(data)
output = torch.sigmoid(output)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
现在结果不好,我做错什么了吗? 任何建议都将真正有用。谢谢。