使用pyplot

时间:2017-11-13 08:59:46

标签: python matplotlib machine-learning linear-regression

我正在使用pyplot绘制线性回归模型。以下是我的代码。

plt.scatter(X_train, y_train, color ='red')
plt.show()

当我使用上面的代码进行绘图时,绘图如下所示: Scatter Plot

然后,我使用以下代码绘制了折线图:

plt.plot(X_train, regressor.predict(X_train), color = 'blue')
plt.show()

按预期显示一行。 Line Grpah

但是当我尝试将它们一起绘制时,图形会变得混乱,如下所示:

plt.scatter(X_train, y_train, color ='red')
plt.plot(X_train, regressor.predict(X_train), color = 'blue')
plt.show()

Linear Regression

如果我需要做任何额外的编码来正确绘制线性回归图,请告诉我。

1 个答案:

答案 0 :(得分:1)

Pyplot按照它们在X_train中出现的顺序连接各点之间的点,但通常对它的排序一无所知。它很少排序。在绘制数组之前,您需要对数组进行排序。

sorted_indices = numpy.argsort(X_train)
sorted_X = X_train[sorted_indices]
plt.plot(sorted_X, regressor.predict(sorted_X), color = 'blue')