Scikit - SGDRegressor won't fit properly

Date: 2018-05-27 10:58:38

标签: python machine-learning scikit-learn linear-regression

Hi, I'm trying to use scikit-learn to fit a small set of data.

import numpy as np
from sklearn import linear_model, model_selection

X = np.array([[86.5999984741211,    9.10000038146973,   14.3000001907349,1],
            [66.9000015258789,  17.3999996185303,   11.5,1],
            [66.3000030517578,  20              ,   10.6999998092651,1],
            [78.6999969482422,  15.3999996185303,   12.1000003814697,1],
            [76.1999969482422,  18.2000007629395,   12.5,1],
            [84.4000015258789,  9.89999961853027,   12.1000003814697,1],
            [79.1999969482422,  8.5             ,   10.1000003814697,1],
            [77.5           ,   10.1999998092651,   11.3999996185303,1],
            [74.4000015258789,  17.7999992370605,   10.6000003814697,1],
            [870.9000015258789, 13.5            ,   13,1],
            [80.0999984741211,  8               ,   9.10000038146973,1],
            [80.0999984741211,  10.3000001907349,   9,1],
            [79.6999969482422,  13.1000003814697,   9.5,1],
            [76.1999969482422,  13.6000003814697,   11.5,1],
            [75.5999984741211,  12.1999998092651,   10.8000001907349,1],
            [81.3000030517578,  13.1000003814697,   9.89999961853027,1],
            [64.5999984741211,  20.3999996185303,   10.6000003814697,1],
            [68.3000030517578,  26.3999996185303,   14.8999996185303,1],
            [80             ,   10.6999998092651,   10.8999996185303,1],
            [78.4000015258789,  9.69999980926514,   12,1],
            [78.8000030517578,  10.6999998092651,   10.6000003814697,1],
            [76.8000030517578,  15.3999996185303,   13,1],
            [82.4000015258789,  11.6000003814697,   9.89999961853027,1],
            [73.9000015258789,  16.1000003814697,   10.8999996185303,1],
            [64.3000030517578,  24.7000007629395,   14.6999998092651,1],
            [81             ,   14.8999996185303,   10.8000001907349,1],
            [70             ,   14.3999996185303,   11.1000003814697,1],
            [76.6999969482422,  11.1999998092651,   8.39999961853027,1],
            [81.8000030517578,  10.3000001907349,   9.39999961853027,1],
            [82.1999969482422,  9.89999961853027,   9.19999980926514,1],
            [76.6999969482422,  10.8999996185303,   9.60000038146973,1],
            [75.0999984741211,  17.3999996185303,   13.8000001907349,1],
            [78.8000030517578,  9.80000019073486,   12.3999996185303,1],
            [74.8000030517578,  16.3999996185303,   12.6999998092651,1],
            [75.6999969482422,  13              ,   11.3999996185303,1],
            [74.5999984741211,  19.8999996185303,   11.1000003814697,1],
            [81.5           ,   11.8000001907349,   11.3000001907349,1],
            [74.6999969482422,  13.1999998092651,   9.60000038146973,1],
            [72             ,   11.1999998092651,   10.8000001907349,1],
            [68.3000030517578,  18.7000007629395,   12.3000001907349,1],
            [77.0999984741211,  14.1999998092651,   9.39999961853027,1],
            [67.0999984741211,  19.6000003814697,   11.1999998092651,1],
            [72.0999984741211,  17.3999996185303,   11.8000001907349,1],
            [85.0999984741211,  10.6999998092651,   10,1],
            [75.1999969482422,  9.69999980926514,   10.3000001907349,1],
            [80.8000030517578,  10              ,   11,1],
            [83.8000030517578,  12.1000003814697,   11.6999998092651,1],
            [78.5999984741211,  12.6000003814697,   10.3999996185303,1],
            [66             ,   22.2000007629395,   9.39999961853027,1],
            [83             ,   13.3000001907349,   10.8000001907349,1],
            [73.0999984741211,  26.3999996185303,   22.1000003814697,1]])

y = np.array([761,
            780,
            593,
            715,
            1078,
            567,
            456,
            686,
            1206,
            723,
            261,
            326,
            282,
            960,
            489,
            496,
            463,
            1062,
            805,
            998,
            126,
            792,
            327,
            744,
            434,
            178,
            679,
            82,
            339,
            138,
            627,
            930,
            875,
            1074,
            504,
            635,
            503,
            418,
            402,
            1023,
            208,
            766,
            762,
            301,
            372,
            114,
            515,
            264,
            208,
            286,
            2922])

# Plain SGD with a constant learning rate, no shuffling, effectively unlimited iterations
model = linear_model.SGDRegressor(max_iter=0x7FFFFFFF, tol=1e-12, learning_rate="constant", eta0=.1, shuffle=False)
# model = linear_model.Lasso(max_iter=0x7FFFFFF, tol=1e-12)

model.fit(X, y)
print(model.coef_)

print(model.score(X, y))

# for i in range(0, len(X)):
#     print(np.dot(X[i], model.coef_))

Ridge / Lasso / ElasticNet fit reasonably well (a score of roughly 0.7), but SGDRegressor can't get anywhere close to that, even when I set an extremely high iteration count and an extremely low tol.

Adjusting max_iter or tol makes no difference to the result; I keep getting huge coefficients.
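
For reference, a minimal sketch of the comparison being described (the hyperparameter values here are illustrative; the exact scores depend on the data above, and the ~0.7 figure comes from the question, not from running this snippet):

from sklearn.linear_model import Ridge, SGDRegressor

# Ridge uses a direct solver, so it converges even though the feature
# scales differ wildly (first column ~60-90, one row ~870).
ridge = Ridge(alpha=1.0).fit(X, y)
print("Ridge R^2:", ridge.score(X, y))

# Plain stochastic gradient descent on the raw features tends to blow up
# here: the large first column dominates the gradient updates.
sgd = SGDRegressor(max_iter=10000, tol=1e-12).fit(X, y)
print("SGD R^2 (unscaled):", sgd.score(X, y))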

1 Answer:

Answer 0 (score: 2)

You need to make sure you scale your features before applying gradient-descent techniques such as SGD. Looking at your X, that should fix the problem.

from sklearn.preprocessing import StandardScaler
X_scaled = StandardScaler().fit_transform(X)
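
A minimal sketch of what the scaled fit might look like end to end; the Pipeline wrapper and the max_iter/tol values are illustrative choices, not part of the original answer:

from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import SGDRegressor

# Scaling happens inside the pipeline, so anything passed to predict()
# later is transformed with the same mean/std learned from the training data.
model = make_pipeline(StandardScaler(), SGDRegressor(max_iter=10000, tol=1e-6))
model.fit(X, y)
print(model.score(X, y))  # should now land in the same range as Ridge/Lasso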