我跟随ML教科书:使用scikit-learn掌握机器学习,虽然我的代码给了我正确的答案,但它与书中的内容并不匹配
首先它给了我这段代码:
import matplotlib.pyplot as plt
X = [[6], [8], [10], [14], [18]]
y = [[7], [9], [13], [17.5], [18]]
plt.figure()
plt.title('Pizza price plotted against diameter')
plt.xlabel('Diameter in inches')
plt.ylabel('Price in dollars')
plt.plot(X, y, 'k.')
plt.axis([0, 25, 0, 25])
plt.grid(True)
plt.show()
这给了我matplotlib中的这个图表:
这与我的结果相符。
然而,在下一步中它给了我这段代码:
from sklearn.linear_model import LinearRegression
# Training data
X = [[6], [8], [10], [14], [18]]
y = [[7], [9], [13], [17.5], [18]]
# Create and fit the model
model = LinearRegression()
model.fit(X, y)
print 'A 12" pizza should cost: $%.2f' % model.predict([12])[0]
这张图表:
那个chard与我的代码不匹配,它没有matplotlib图表制作功能。我试着阅读指南并制作自己的指南:
from sklearn.linear_model import LinearRegression
import numpy as np
import matplotlib.pyplot as plt
X = [[6], [8], [10], [14], [18]]
y = [[7], [9], [13], [17.5], [18]]
model = LinearRegression()
model.fit(X, y)
z = np.array([12]).reshape(-1,1)
print ('A 12" pizza should cost: $%.2f' % model.predict(z)[0])
print ("\n" + "_" * 50 + "\n")
plt.figure()
plt.title('Pizza price plotted against diameter')
plt.xlabel('Diameter in inches')
plt.ylabel('Price in dollars')
plt.plot(X, y, z, 'k.')
plt.axis([0, 25, 0, 25])
plt.grid(True)
plt.show()
但这只是给了我这个奇怪的蓝色事物:
我刚才在python中使用数学,所以如果有人能给我更多关于如何解决这个问题的信息,我们将不胜感激。
答案 0 :(得分:1)
这个"奇怪的蓝色东西"你得到的是你的数据通过线段连接在一起;您的数据应使用plt.scatter
绘制,这会为您提供一堆积分。
您对回归线的计算是正确的,需要修复的是如何在数据集上绘制该线:
拟合数据后,需要提取绘制回归线所需的值;您需要的数据是x轴每个末端的两个点(此处为x=0
和x=25
)。如果我们在这两个值上调用model.predict
,我们就会获得相应的预测。这些x值与它们相应的预测相结合形成了两个点,我们将用它来绘制线。
首先,我们提取预测值y0
和y25
。然后我们使用plt.plot
和点(0,y0)和(25,y25)来绘制绿色的回归线。
from sklearn.linear_model import LinearRegression
import numpy as np
import matplotlib.pyplot as plt
X = [[6], [8], [10], [14], [18]]
y = [[7], [9], [13], [17.5], [18]]
model = LinearRegression()
model.fit(X, y)
z = np.array([12]).reshape(-1,1)
print ('A 12" pizza should cost: $%.2f' % model.predict(z)[0])
print ("\n" + "_" * 50 + "\n")
plt.figure()
plt.title('Pizza price plotted against diameter')
plt.xlabel('Diameter in inches')
plt.ylabel('Price in dollars')
plt.scatter(X, y, z, 'k')
y0, y25 = model.predict(0)[0][0], model.predict(25)[0][0]
plt.plot((0, 25), (y0, y25), 'g')
plt.axis([0, 25, 0, 25])
plt.grid(True)
plt.show()