我无法在sklearn中顺利进行多项式回归

时间:2019-01-12 15:16:03

标签: python matplotlib scikit-learn

对于使用sklearn在python中实现的多项式回归代码

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures
import pandas as pd

m=100
X=6*np.random.rand(m,1)-3
y=0.5*X**3+X+2+np.random.rand(m,1)

poly_features=PolynomialFeatures(3)
X_poly=poly_features.fit_transform(X)
lin_reg=LinearRegression()

X_new=[[0],[3]]

lin_reg.fit(X_poly,y)

plt.plot(X,y,"b.")
plt.plot(X, lin_reg.predict(poly_features.fit_transform(X)),  "r-")

plt.show()

输出显示为

enter image description here

但是我想获得一条平滑的预测线。如何获得?

1 个答案:

答案 0 :(得分:0)

问题是X数组未排序。因此,当您使用线-r绘制数据时,它将按未排序的X数据点的顺序连接数据点。因此,您会看到随机的线网。有标记的绘图的顺序无关紧要,因为您只是在画没有线的点。

解决方案是对X数据进行排序,然后将排序后的X数据传递到plot命令,并相应地传递到fit_transform

shape = X.shape
X = np.sort(X.flatten())
plt.plot(X, lin_reg.predict(poly_features.fit_transform(X.reshape((shape)))),  "r-", lw=2)

enter image description here