我现在坚持解决这个问题两天了。我有一些数据点放在scatter plot
中并得到这个:
哪个好,但是现在我还想添加回归线,所以我从sklearn看了一下这个example并将代码更改为
import numpy as np
import matplotlib.pyplot as plt
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_val_score
degrees = [3, 4, 5]
X = combined[['WPI score']]
y = combined[['CPI score']]
plt.figure(figsize=(14, 5))
for i in range(len(degrees)):
ax = plt.subplot(1, len(degrees), i + 1)
plt.setp(ax, xticks=(), yticks=())
polynomial_features = PolynomialFeatures(degree=degrees[i], include_bias=False)
linear_regression = LinearRegression()
pipeline = Pipeline([("polynomial_features", polynomial_features), ("linear_regression", linear_regression)])
pipeline.fit(X, y)
# Evaluate the models using crossvalidation
scores = cross_val_score(pipeline, X, y, scoring="neg_mean_squared_error", cv=10)
X_test = X #np.linspace(0, 1, len(combined))
plt.plot(X, pipeline.predict(X_test), label="Model")
plt.scatter(X, y, label="CPI-WPI")
plt.xlabel("X")
plt.ylabel("y")
plt.legend(loc="best")
plt.title("Degree {}\nMSE = {:.2e}(+/- {:.2e})".format(degrees[i], -scores.mean(), scores.std()))
plt.savefig(pic_path + 'multi.png', bbox_inches='tight')
plt.show()
具有以下输出:
请注意,X
和y
都是DataFrames
,大小为(151, 1)
。如有必要,我也可以发布X和y的内容。
我想要的是一条很流畅的线条,但我似乎无法弄明白,如何做到这一点。
[编辑]
这里的问题是:如何获得单个平滑,弯曲的多项式线而不是具有看似随机模式的多个线。
[编辑2]
问题是,当我像这样使用linspace
时:
X_test = np.linspace(1, 4, 151)
X_test = X_test[:, np.newaxis]
我得到一个更随机的模式:
答案 0 :(得分:1)
诀窍是设置如下代码:
public class Main implements Runnable {
private static BufferedReader reader;
public static void main(String[] args) {
reader = new BufferedReader(new InputStreamReader(System.in));
new Thread(new Main()).start();
try {
Thread.sleep(1);
} catch (InterruptedException e) {
e.printStackTrace();
}
try {
//System.in.close(); // <-- undefined behavior
reader.close(); // <-- undefined behavior
} catch (IOException e) {
e.printStackTrace();
}
}
@Override
public void run() {
try {
reader.readLine();
} catch (IOException ignore) {
}
}
}
产生以下结果(更好,单一平滑线)