我开始使用Pytorch,目前正在研究一个项目,该项目使用简单的前馈神经网络进行线性回归。问题是我在Pytorch中找不到任何可以像Keras或SKlearn中那样获得线性回归模型精度的东西。在keras中,仅需在compile函数中设置metrics=["accuracy"]
即可。我在Pytorch的文档和官方网站中进行了搜索,但没有找到任何东西。似乎该API在Pytorch中不存在。我知道我可以在训练过程中观察损失,也可以简单地得到测试损失,并据此知道损失是否减少了,但是我想使用Keras结构来获取损失值和准确度值。 Keras的方法看起来更加清晰。我还尝试使用sklearn的r2_score实现精度函数,但它给了我一些奇怪的值:
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
def train(model, optimizer, loss_fn):
def train_step(x, y):
model.train()
optimizer.zero_grad()
out = model(x)
loss = loss_fn(out, y)
loss.backward()
optimizer.step()
return loss.item()
return train_step
def fit(epochs=100):
train_func = train(model, optimizer, criterion)
count, total = 0, 0
loss_list, accuracy_list, iters = [], [], []
for e in range(epochs):
for X, y in train_loader:
loss = train_func(X, y)
count += 1
total += len(y)
if count % 50 == 0:
print("loss= ", loss)
loss_list.append(loss)
iters.append(total)
if count % 100 == 0:
model.eval() # im not sure if we can do this in pytorch. I mean evaluating the model while training! it would be great if you tell me whether this is ok or not
out = model(X)
out = out.detach().numpy()
y = y.detach().numpy()
accuracy = r2_score(y, out) # r2_score is the scikit learn r2 score function.
print("accuracy = ", accuracy) # here i get wierd values and it doesn't get better over time, in contrast the loss decreased over time
accuracy_list.append(accuracy)
return iters, loss_list, accuracy_list
我知道在分类问题的情况下如何实现精度功能,因为它使用离散值。这对我来说很清楚,因为实现起来很简单明了。我只能查看模型做出了哪个正确的预测,然后计算准确性。但是在这种情况下,我具有连续的值,所以这就是为什么我自己无法实现该功能,而Pytorch对此没有内置功能,这令我感到惊讶。所以有人可以告诉我如何实现它,或在哪里找到它的实现?
另一件事是在何处使用评估,以及在何处通过调用eval函数在评估模式下设置模型。我应该像在代码中一样在训练过程中使用它,还是应该训练然后在训练后进行测试;如果我在训练过程中进行测试,是否应该像在此那样调用eval函数,否则当循环返回训练时会影响训练模式?我在Pytorch中也找不到它,这是交叉验证。如果没有像Keras这样的API,该如何在pytorch中实现?
答案 0 :(得分:-2)
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (
100 * correct / total))
在此处查看更多信息:https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html