使用sklearn.linear_model绘制奇怪的图

时间:2017-11-18 01:11:45

标签: python scikit-learn linear-regression

我通常使用MATLAB,但想要自己去学习一些关于Python的东西。我尝试了由youtuber引入的线性回归代码。这是代码:

import pandas as pd
from sklearn import linear_model
import matplotlib.pyplot as plt

#read data
dataframe = pd.read_fwf('brain_body.txt')
x_values = dataframe[['Brain']]
y_values = dataframe[['Body']]

#train model on data
body_reg = linear_model.LinearRegression()
body_reg.fit(x_values,y_values)

#visualize results
plt.scatter(x_values,y_values)
plt.plot(x_values,body_reg.predict(x_values))
plt.show() 

但我最终得到了一个非常奇怪的情节(我使用的是Python 3.6): 1

这是细节的一部分: 2

显然,某些事情缺失或错误。

可以在https://github.com/llSourcell/linear_regression_demo/blob/master/brain_body.txt

中找到brain_body.txt的数据

欢迎任何建议或建议。

更新

我试过了sera的代码,这就是我得到的: 3

这很有趣也很奇怪。我发现我的数据文件有问题,或者我的Python中缺少某些东西,但我只是将原始数据复制并粘贴到记事本中并保存为.txt;我试过Python 3.6和2.7以及Pycharm和Spyder ......所以我不知道...... 顺便说一下,YouTube视频是here

@sascha @Moritz @sera我让我的朋友运行相同的代码和数据文件,一切都很好。换句话说,我的Python有问题,我不知道为什么。让我尝试另一台计算机和/或尝试早期版本的python。

我试过,但没有改变。以下是我用于安装Python的两种不同方法: 1.安装Python(例如版本3.6);安装Pycharm;安装包Pandas,scikit-learn ... 2.安装Anaconda

解决

感谢@Marc Bataillou的建议。这是与不同版本的matplotlib相关的问题。问题发生在2.1.0版本中。我试过2.0.2并发现原始代码在旧版本中运行良好;显然,一些变化是从2.0.2到2.1.0。感谢您的所有努力。

2 个答案:

答案 0 :(得分:0)

你应该使用 plt.scatter(x_values.values,y_values.values) 代替 plt.scatter(x_values,y_values)

我希望它有效!

答案 1 :(得分:-1)

您可以使用以下代码显示结果。我使用交叉验证进行预测。如果模型是完美的,那么所有点都将在绘制的线上。

import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import cross_val_predict
from sklearn import linear_model

#read data
dataframe = pd.read_fwf('brain_body.txt')

x_values = dataframe[['Brain']]
y_values = dataframe[['Body']]

#model on data
body_reg = linear_model.LinearRegression()

# cross_val_predict returns an array of the same size as `y` where each entry
# is a prediction obtained by cross validation:
predicted = cross_val_predict(body_reg, x_values, y_values, cv=10)

fig, ax = plt.subplots()
ax.scatter(y_values, predicted, edgecolors=(0, 0, 0))
ax.plot([y_values.min(), y_values.max()], [y_values.min(), y_values.max()], 'k--', lw=4)
ax.set_xlabel('Measured')
ax.set_ylabel('Predicted')
plt.show()

<强>结果:

enter image description here

数据

https://ufile.io/p7x0r