RuntimeError:断言`cur_target> = 0 && cur_target <n_classes'失败

时间:2019-05-04 09:15:35

标签: python python-3.x machine-learning anaconda cross-entropy

我得到:

  

RuntimeError:断言`cur_target> = 0 && cur_target

运行此代码时:

    criterion = nn.CrossEntropyLoss()
    #Define the optimizer
    optimizer=optim.SGD(net.parameters(),lr=0.01,momentum=0.9)
    epochs=20
    for epoch in range(epochs):
        print ("epoch #", epoch)
        running_loss=0.0
        for i, data in enumerate(train_loader,0):
            inputs,labels=data
            inputs,labels= inputs.to(device),labels.to(device)
            optimizer.zero_grad()   
            #train
            output=net(inputs)
            loss=criterion(output,labels)

    print ("loss: ", loss.item())
    running_loss+=loss.item()
    loss.backward()
    optimizer.step()
    print ('Finished Training')

2 个答案:

答案 0 :(得分:0)

该异常表明您的标签之一超出范围。 也许它们从1开始而不是0?尝试将它们打印出来。

答案 1 :(得分:0)

我遇到了这个确切的错误(是的,它来自 Pytorch),我会发布我的解决方案,以防其他人可以从中受益。

我的错误是因为我的分类器只有 2 个输出,但数据有 3 个标签。

通过确保分类器给出 3 个类来修复。