在pytorch中建立多标签网络

时间:2019-08-01 19:27:32

标签: pytorch

试图解决多标签图像分类问题,即8类多标签问题。我的标签看起来像这样

    [0. 1. 0. 0. 1. 0. 0. 0.]
    [0. 1. 0. 0. 0. 0. 0. 0.]
    [0. 1. 0. 1. 0. 0. 0. 0.]

代码看起来像这样

model = models.resnet50(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 8)
model = model.cuda()
# Decay LR by a factor of 0.1 every 7 epochs
# exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
loss_fn = torch.nn.BCELoss()



# def train(epoch):
for epoch in range(15):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # import pdb;pdb.set_trace()
        data, target = data.cuda(), target.float().cuda()
        optimizer.zero_grad()
        output = model(data)
        output = torch.sigmoid(output)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()

现在结果不好,我做错什么了吗? 任何建议都将真正有用。谢谢。

0 个答案:

没有答案