我正在使用matplotlib.pyplot
在jupyter笔记本中绘制连续曲线。我使用了以下代码:
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures
X_train, X_test, y_train, y_test = train_test_split(X.reshape(-1,1), y, random_state = 0)
poly = PolynomialFeatures(degree=9)
X_train_p = poly.fit_transform(X_train)
X_test_p = poly.fit_transform(X_test)
plt.figure(figsize=(5,5))
plt.title("deg={}".format(9))
plt.plot(X_train, y_train.reshape(-1,1), 'r')
plt.show()
我希望数据点将通过直线连续连接,但是结果像这样:
我尝试使用X_train
重塑y_train
和.reshape()
的多种形式,但没有获得预期的结果。