Polynomial regression: printing the feature names and intercept of the learned function

Time: 2018-08-26 16:25:21

Tags: python machine-learning scikit-learn linear-regression

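I am trying to print the function that scikit-learn learns when doing polynomial regression.

I wrote the following code based on this example and on what I learned from this question. It runs, but I'm not sure how to interpret the output.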

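The intercept should be separate from the coefficients, but get_feature_names gives me a column named "1", which sounds like the name of an intercept column.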
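For reference, a minimal sketch of what that "1" column is. It assumes a recent scikit-learn, where `get_feature_names_out` replaced `get_feature_names` (the fallback covers older releases such as 0.18). With the default `include_bias=True`, `PolynomialFeatures` prepends a constant column of ones, and that is the column it names "1"; the toy input `X` here is made up for illustration.

```python
import numpy as np
from sklearn.preprocessing import PolynomialFeatures

# toy input: three samples of a single feature (illustrative values)
X = np.array([[2.0], [3.0], [5.0]])
poly = PolynomialFeatures(degree=3)  # include_bias=True is the default
poly_X = poly.fit_transform(X)

try:
    names = list(poly.get_feature_names_out(['x']))  # scikit-learn >= 1.0
except AttributeError:
    names = poly.get_feature_names(['x'])  # older scikit-learn

print(names)         # ['1', 'x', 'x^2', 'x^3']
print(poly_X[:, 0])  # the "1" column is all ones: [1. 1. 1.]
```

So "1" is the name of the degree-0 (bias) feature x^0, not of the model's intercept.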

After fitting the model, this column's coefficient is zero, but the model's intercept is -0.122 (not zero).
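One way to see where the zero comes from, as a sketch assuming `Ridge` keeps its default `fit_intercept=True`: the model centers every feature column before solving, a column of ones centers to all zeros and therefore gets coefficient 0, and the constant term is carried by `intercept_` instead. If that reading is right, dropping the column with `include_bias=False` should leave the predictions unchanged. The names `with_bias`, `no_bias`, `clf_a` and `clf_b` are illustrative.

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.preprocessing import PolynomialFeatures

# same 20 training points as in the question
x = np.linspace(0, 10, 100)
rng = np.random.RandomState(0)
rng.shuffle(x)
x = np.sort(x[:20])
y = x * np.sin(x)
X = x[:, np.newaxis]

# default include_bias=True: features are 1, x, ..., x^5
with_bias = PolynomialFeatures(5)
clf_a = Ridge().fit(with_bias.fit_transform(X), y)

# include_bias=False: features are x, ..., x^5 (no "1" column)
no_bias = PolynomialFeatures(5, include_bias=False)
clf_b = Ridge().fit(no_bias.fit_transform(X), y)

# Ridge centers each column; the column of ones becomes all zeros,
# so its coefficient ends up at 0
print(clf_a.coef_[0])

# both models should predict the same values on the training points
pred_a = clf_a.predict(with_bias.transform(X))
pred_b = clf_b.predict(no_bias.transform(X))
print(np.allclose(pred_a, pred_b, atol=1e-6))

# in both cases the constant term lives in intercept_
print(clf_a.intercept_, clf_b.intercept_)
```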

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import Ridge
from sklearn.preprocessing import PolynomialFeatures


# function that we want to learn
def f(x):
    return x * np.sin(x)


# generate points and keep a subset of them
x = np.linspace(0, 10, 100)
rng = np.random.RandomState(0)
rng.shuffle(x)
x = np.sort(x[:20])
y = f(x)

# from 1d vector to 2d array, multiple rows with a single column
X = x[:, np.newaxis]

# add polynomial features to X
degree = 5
poly = PolynomialFeatures(degree)
poly_X = poly.fit_transform(X)
feature_names = ['x']
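try:
    poly_feature_names = poly.get_feature_names_out(feature_names)  # scikit-learn >= 1.0
except AttributeError:
    poly_feature_names = poly.get_feature_names(feature_names)  # older versions, e.g. 0.18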

# fit the ridge linear regression model
clf = Ridge()
clf.fit(poly_X, y)

# print learned function in readable form
learned_f = 'learned function:\n'
intercept = clf.intercept_
learned_f += '{:+.3f}'.format(intercept)
coefficients = clf.coef_
assert len(coefficients) == len(poly_feature_names)
assert coefficients[0] == 0.0  # the coefficient of the "1" column is 0
for coef, name in zip(coefficients, poly_feature_names):
    learned_f += ' {:+.3f} {}'.format(coef, name)
print(learned_f)

# generate points used to plot
x_plot = np.linspace(0, 10, 100)
X_plot = x_plot[:, np.newaxis]

# plot function we want to learn
lw = 2
plt.plot(x_plot, f(x_plot), color='cornflowerblue',
         linewidth=lw, label='ground truth')
plt.scatter(x, y, color='navy', s=30, marker='o',
            label='training points')

# plot points predicted by the learned function
poly_X_plot = poly.transform(X_plot)
y_plot = clf.predict(poly_X_plot)
plt.plot(x_plot, y_plot, color='teal', linewidth=lw,
         label='learned poly degree %d' % degree)

# place legend and show plot
plt.legend(loc='lower left')
plt.tight_layout()
plt.show()

Here is the output; note the first two terms, -0.122 and +0.000 1:

learned function:
-0.122 +0.000 1 +1.088 x +1.132 x^2 -0.950 x^3 +0.180 x^4 -0.010 x^5
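A quick way to check that the printed string really is the learned function is to evaluate intercept + Σ coef_k · x^k by hand and compare it with `clf.predict`. This is a sketch assuming the question's data and default `Ridge` settings; `learned_f` and `x_test` are hypothetical names. Since the "1" term contributes 0 × 1 = 0, the constant part of the function is just `intercept_` (the -0.122 printed above).

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.preprocessing import PolynomialFeatures

# same 20 training points as in the question
x = np.linspace(0, 10, 100)
rng = np.random.RandomState(0)
rng.shuffle(x)
x = np.sort(x[:20])
y = x * np.sin(x)

poly = PolynomialFeatures(5)
poly_X = poly.fit_transform(x[:, np.newaxis])
clf = Ridge().fit(poly_X, y)


def learned_f(x_new):
    """Evaluate intercept_ + sum_k coef_[k] * x_new**k, i.e. the printed function."""
    # for a single input feature, column k of poly_X is x**k,
    # so coef_[k] multiplies x**k (coef_[0] belongs to the "1" column)
    powers = np.arange(len(clf.coef_))
    return clf.intercept_ + np.sum(clf.coef_ * x_new[:, np.newaxis] ** powers, axis=1)


x_test = np.array([0.5, 2.5, 7.5])
manual = learned_f(x_test)
sklearn_pred = clf.predict(poly.transform(x_test[:, np.newaxis]))
print(np.allclose(manual, sklearn_pred))
```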

This plot helps show what is going on:

[Figure: polynomial regression learning sine function]

Is my code correct? If so, how should I interpret the column named "1"? Is it a placeholder for the intercept, or something else?
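Which part of the model carries the constant term depends on `fit_intercept`. A hedged sketch, assuming the question's data: with `Ridge(fit_intercept=False)` no separate intercept is fitted (`intercept_` is set to 0.0), and the coefficient of the "1" column takes over that role. Unlike `intercept_`, that coefficient is penalized by the ridge term, so its value should not be expected to equal the -0.122 from the default fit.

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.preprocessing import PolynomialFeatures

# same 20 training points as in the question
x = np.linspace(0, 10, 100)
rng = np.random.RandomState(0)
rng.shuffle(x)
x = np.sort(x[:20])
y = x * np.sin(x)

poly_X = PolynomialFeatures(5).fit_transform(x[:, np.newaxis])

# no separate intercept: the "1" column has to model the constant itself
clf = Ridge(fit_intercept=False).fit(poly_X, y)
print(clf.intercept_)  # 0.0
print(clf.coef_[0])    # coefficient of the "1" column, now generally nonzero
```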

Here are my specs:

Linux-4.9.0-7-amd64-x86_64-with-debian-9.5
('Python', '2.7.13 |Anaconda 4.4.0 (64-bit)| (default, Dec 20 2016, 23:09:15) \n[GCC 4.4.7 20120313 (Red Hat 4.4.7-1)]')
('NumPy', '1.12.1')
('SciPy', '0.19.0')
('Scikit-Learn', '0.18.1')
