Linear regression on a filtered dataset

Time: 2020-05-19 01:29:38

Tags: python pandas numpy matplotlib linear-regression

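After finalizing my dataset and getting the plot to work, I've been trying to fit a line to it with linear regression. I've tried several approaches, but none of them gave me any results, and I think it's because my data has been filtered. Here's my code: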

from matplotlib import pyplot as plt
import numpy as np
import pandas as pd  # needed for pd.read_csv below
from pandas import DataFrame
from sklearn.linear_model import LinearRegression
from matplotlib.pyplot import figure

figure(num=None, figsize=(100, 100), dpi=100, facecolor='w', edgecolor='k')

plt.rc('font', size=100)          # controls default text sizes
plt.rc('axes', titlesize=100)     # fontsize of the axes title
plt.rc('axes', labelsize=100)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=30)    # fontsize of the tick labels
plt.rc('ytick', labelsize=60)    # fontsize of the tick labels
plt.rc('legend', fontsize=100)    # legend fontsize
plt.rc('figure', titlesize=100)

plt.xticks(rotation=90)


ds = pd.read_csv("https://covid.ourworldindata.org/data/owid-covid-data.csv")
df = DataFrame(ds, columns = ['date', 'location', 'new_deaths', 'total_deaths'])

df = df.replace(np.nan, 0)

US = df.loc[df['location'] == 'United States']


plt.plot_date(US['date'],US['new_deaths'], 'blue', label = 'US', linewidth = 5)
#plt.plot_date(US['date'],US['total_deaths'], 'red', label = 'US', linewidth = 5)

#linear_regressor = LinearRegression()  # create object for the class
#linear_regressor.fit(US['date'], US['new_deaths'])  # perform linear regression
#Y_pred = linear_regressor.predict(X)  # make predictions

#m , b = np.polyfit(x = US['date'], y = US['new_deaths'], deg = 1)






plt.title('New Deaths per Day In US')
plt.xlabel('Time')
plt.ylabel('New Deaths')
plt.legend()
plt.grid()
plt.show()


I know this question has been asked thousands of times, so if there's a post I haven't come across, please link it to me. Thanks, everyone! :D

1 Answer:

Answer 0: (score: 0)

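With sklearn's LinearRegression, you can fit the regression as follows. The date column comes out of the CSV as strings, so it has to be converted to numbers first (ordinal day numbers here):

# parse the date strings and turn them into ordinal day numbers
dates = pd.to_datetime(US['date'])
x = dates.map(pd.Timestamp.toordinal).values.reshape(-1, 1)

regr = LinearRegression()
regr.fit(x, US['new_deaths'])

To plot it:

# plot the original points
plt.plot(dates, US['new_deaths'])

# plot the fitted line. To do so, first generate an input set containing
# only the max and min limits of the x range
trendline_x = np.array([x.min(), x.max()]).reshape(-1, 1)
# predict the y values of these two points
trendline_y = regr.predict(trendline_x)
# plot the trendline against the matching dates
plt.plot([dates.min(), dates.max()], trendline_y)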
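
The np.polyfit attempt from the question should work in much the same way once the dates are numeric. Here is a minimal sketch on made-up data: the column names are the ones from the question, but the "Elsewhere" rows and all the values are invented for illustration. It also suggests that filtering by location is not what breaks the fit.

```python
import numpy as np
import pandas as pd

# Made-up data standing in for the real CSV: two locations, ten days each.
days = pd.date_range("2020-03-01", periods=10, freq="D").strftime("%Y-%m-%d")
df = pd.DataFrame({
    "date": list(days) * 2,
    "location": ["United States"] * 10 + ["Elsewhere"] * 10,
    "new_deaths": [3.0 * i + 2.0 for i in range(10)] + [0.0] * 10,
})

# Filtering by location is harmless; the fit only needs numeric x values.
US = df.loc[df["location"] == "United States"]

# Convert the date strings to "days since the first date".
dates = pd.to_datetime(US["date"])
x = (dates - dates.min()).dt.days.to_numpy()
y = US["new_deaths"].to_numpy()

# A degree-1 polyfit returns the slope first, then the intercept.
m, b = np.polyfit(x, y, deg=1)
trend = m * x + b  # fitted values, one per original row
```

Plotting `trend` against `dates` with `plt.plot(dates, trend)` would then draw the line over the original points.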

If you just want the visual, Seaborn's lmplot is a convenient and good-looking option.
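
A rough sketch of the lmplot route, assuming seaborn is installed. The `day` column, the `plot_trend` helper and the numbers are hypothetical, chosen only for illustration; lmplot fits on numeric x values, so a day count is used here instead of date strings.

```python
import pandas as pd
import seaborn as sns


def plot_trend(frame):
    """Draw a scatter plot with a fitted regression line via seaborn's lmplot."""
    # lmplot expects numeric x values, hence a day count rather than date strings.
    return sns.lmplot(data=frame, x="day", y="new_deaths")


# Made-up numbers, for illustration only.
demo = pd.DataFrame({
    "day": list(range(10)),
    "new_deaths": [2, 5, 9, 10, 15, 16, 21, 22, 27, 29],
})
grid = plot_trend(demo)
```

Calling `plt.show()` from matplotlib afterwards displays the figure; `grid` is the FacetGrid that lmplot returns.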
