The model trains, but cannot predict?

Time: 2019-04-01 12:10:48

Tags: linear-regression pytorch

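I followed a tutorial and wrote some linear regression code for the Boston housing prices. It runs fine and the loss keeps getting smaller, but when I plotted the results with matplotlib, the plot did not look the way I expected.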

I searched, but could not solve my problem.

import pandas as pd
from sklearn.datasets import load_boston
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import torch
from torch.autograd import Variable

import matplotlib.pyplot as plt

if __name__ == '__main__':
    boston = load_boston()
    col_names = ['feature_{}'.format(i) for i in range(boston['data'].shape[1])]
    df_full = pd.DataFrame(boston['data'], columns=col_names)

    scalers_dict = {}
    for col in col_names:
        scaler = StandardScaler()
        df_full[col] = scaler.fit_transform(df_full[col].values.reshape(-1, 1))
        scalers_dict[col] = scaler

    x_train, x_test, y_train, y_test = train_test_split(df_full.values, boston['target'], test_size=0.2, random_state=2)

    model = torch.nn.Sequential(torch.nn.Linear(x_train.shape[1], 1), torch.nn.ReLU())

    criterion = torch.nn.MSELoss(reduction='mean')
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    n_epochs = 2000

    train_loss = []
    test_loss = []
    x_train = Variable(torch.from_numpy(x_train).float(), requires_grad=True)
    y_train = Variable(torch.from_numpy(y_train).float())

    for epoch in range(n_epochs):
        y_hat = model(x_train)
        loss = criterion(y_hat, y_train)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss = loss.data ** (1/2)
        train_loss.append(epoch_loss)
        if (epoch + 1) % 250 == 0:
            print("{}:loss = {}".format(epoch + 1, epoch_loss))

    order = y_train.argsort()
    y_train = y_train[order]
    x_train = x_train[order, :]

    model.eval()

    predicted = model(x_train).detach().numpy()
    actual = y_train.numpy()
    print('predicted:', predicted[:5].flatten(), actual[:5])
    plt.plot(predicted.flatten(), 'r-', label='predicted')
    plt.plot(actual, 'g-', label='actual')
    plt.show()

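Why are the predictions all the same, like [22.4413, 22.4413, ...]? In the plot they form a horizontal line. I am a beginner in deep learning, and I would really appreciate your help!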
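One possible cause, offered as a guess rather than a confirmed diagnosis: in the code above, `model(x_train)` has shape `(N, 1)` while `y_train` has shape `(N,)`. If the loss broadcasts these two shapes against each other, every prediction is compared with every target, and a constant output near the mean of the targets becomes the best the model can do. The sketch below uses made-up toy values (not the Boston data) and computes the mean squared error by hand to show the effect; as far as I know, `torch.nn.MSELoss` broadcasts the same way.

```python
import torch

# Hypothetical toy targets with shape (4,), like y_train in the question.
y = torch.tensor([10.0, 20.0, 30.0, 40.0])
# "Perfect" predictions with shape (4, 1), like the output of torch.nn.Linear(..., 1).
y_hat = torch.tensor([[10.0], [20.0], [30.0], [40.0]])

# Broadcasting expands (4, 1) against (4,) to (4, 4), so each prediction
# is compared with every target, not only its own.
diff = y_hat - y
broadcast_loss = (diff ** 2).mean()

# With aligned shapes, each prediction is compared with its own target.
aligned_loss = ((y_hat.squeeze(1) - y) ** 2).mean()

# A constant prediction equal to the target mean scores better under the
# broadcast loss than the perfect predictions do.
const = torch.full((4, 1), y.mean().item())
const_broadcast_loss = ((const - y) ** 2).mean()

print(tuple(diff.shape))             # (4, 4)
print(broadcast_loss.item())         # 250.0
print(aligned_loss.item())           # 0.0
print(const_broadcast_loss.item())   # 125.0
```

If this is what is happening, making the shapes match before computing the loss, for example `criterion(y_hat.squeeze(1), y_train)` or reshaping the targets with `y_train.view(-1, 1)`, would be worth trying. The trailing `torch.nn.ReLU()` also clamps negative outputs to zero, which a plain linear regression normally does not need.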

0 answers:

No answers