使用XGBoost提高预测准确性

时间:2018-09-30 19:00:35

标签: python scikit-learn xgboost

我有一个32x20的矩阵,正在尝试使用XGBoost(回归)。我逐行滚动遍历数据,以生成样本外预测。令我惊讶的是,XGBoost的样本外误差(MAPE)只能做到3-4%,而用其他算法(glmboost,即增强的线性模型)处理同样的数据时,MAPE约为1.8-2.5%。我没想到XGBoost的表现会差这么多,怀疑是我的超参数调优做得不够。下面附上了我运行过的网格搜索(grid search),但误差并没有改善。我遗漏了什么?数据如下。

import pandas as pd
import numpy as np
import xgboost as xgb

# Rolling one-step-ahead evaluation of an XGBoost regressor.
#
# In round i the model is trained on rows 0 .. l-i-1 and predicts row l-i.
# Rounds 1 .. N_ROUNDS-1 therefore give out-of-sample predictions for rows
# l-N_HOLDOUT .. l-1, which are scored with MAPE against their actual targets.
# Round 0 predicts the final row l, which is never scored here (presumably its
# target is not yet known -- confirm against the input file).
Macro = pd.read_csv("P:/Earnest/Old/R/InputFull.csv")
print(Macro)
print(Macro["y"])

N_ROUNDS = 6                 # number of rolling train/predict rounds
N_HOLDOUT = N_ROUNDS - 1     # rounds whose predicted row has a known target

l = len(Macro) - 1  # positional index of the last row

predictlist = []

for i in range(N_ROUNDS):
    print('round = ', i+1)
    # Column 1 is the target (presumably the "y" column printed above --
    # confirm); columns 2..20 are the 19 feature columns.
    y = Macro.iloc[0:l-i, 1:2]
    x = Macro.iloc[0:l-i, 2:21]
    t = Macro.iloc[l-i:l+1-i, 2:21]  # the single row right after the training window

    # 'reg:squarederror' is the same squared-error loss that the deprecated
    # 'reg:linear' alias selected. The .fit() chained onto the constructor
    # already trains the model; the former second model.fit(x, y) only
    # repeated the same training.
    # NOTE(review): y rises from ~670 to ~1090 over the sample (see the data
    # dump). Tree boosters give piecewise-constant predictions and cannot follow
    # a trend past the range seen in training, which may explain losing to a
    # linear booster (glmboost). Consider booster='gblinear' or detrending y.
    model = xgb.XGBRegressor(objective='reg:squarederror', colsample_bytree=0.3,
                             learning_rate=0.05, max_depth=10, alpha=1,
                             n_estimators=1000).fit(x, y)

    pred = model.predict(t)
    print(pred)
    predictlist.append(pred)

# Aggregate forecasts and MAPE.
# Round 0 (the forecast for row l) was appended first. Reversing the index puts
# the rows in chronological order: positions 0..N_HOLDOUT-1 hold the
# predictions for rows l-N_HOLDOUT .. l-1, and position N_HOLDOUT holds row l.
predictions = pd.DataFrame(predictlist).sort_index(ascending=False)

y_hat = predictions.iloc[0:N_HOLDOUT]
y_actual = pd.DataFrame(Macro.iloc[l-N_HOLDOUT:l, 1:2])

mape = np.mean(np.abs((y_hat.values - y_actual.values) / y_actual.values) * 100)

# (forecast for the final row, hold-out MAPE in percent)
forecast = (predictions.iloc[N_HOLDOUT].values[0], mape)

print(predictions.to_string(index=False, header=False))
print(*predictions.iloc[N_HOLDOUT], mape)

我也尝试了下面这个网格搜索,但样本外性能并没有提高:

# Hyper-parameter search for the XGBoost regressor.
# GridSearchCV was used without being imported; TimeSeriesSplit comes from the
# same scikit-learn module.
from sklearn.model_selection import GridSearchCV, TimeSeriesSplit

gbm_param_grid = {
    'colsample_bytree': [0.3, 0.5, 0.7, 0.9],
    'n_estimators': [50, 100, 200, 500, 750, 1000],
    'max_depth': [1, 2, 3, 4, 5],
    'learning_rate': [0.1, .01]}

# 'reg:squarederror' is the same loss as the deprecated 'reg:linear' alias.
gbm = xgb.XGBRegressor(objective='reg:squarederror')

# The rolling loop treats rows as time-ordered. An integer cv=5 gives an
# unshuffled KFold for a regressor, so most folds train on later rows and
# score earlier ones (look-ahead leakage). TimeSeriesSplit always validates
# on rows that come after the training rows.
model = GridSearchCV(estimator=gbm, param_grid=gbm_param_grid,
                     scoring='neg_mean_squared_error',
                     cv=TimeSeriesSplit(n_splits=5), verbose=1)

# The original snippet never fitted the search. x / y are the globals left by
# the rolling loop above (its last and shortest training window), so the
# hold-out rows scored by the MAPE are not used for tuning.
model.fit(x, y)
print(model.best_params_, -model.best_score_)

数据:

json.dumps(y.values.tolist()) 
Out[106]: '[[668.39], [524.019], [609.181], [609.953], [730.648], [568.93], [676.2689999999999], [692.8939999999999], [856.832], [648.177], [758.5239999999999], [774.049], [905.8580000000001], [686.31], [811.253], [814.47], [1011.044], [739.01], [867.46], [825.258], [1013.4060000000001], [762.577], [890.568], [862.4910000000001], [1030.2], [761.2], [872.93], [892.77], [1089.12], [855.69], [992.454]]'

json.dumps(x.values.tolist()) 
Out[110]: '[[1.386295602, 0.48050087399999997, 1.553830985, 2.786758621, 1.0960262809999999, 0.781628874, 1.02046884, 0.624766272, 0.014193008, 0.00904413, 0.015319073999999998, 0.026494721000000002, 0.015739028999999998, 0.012574278999999999, 0.015046687, 0.008459999999999999, 0.0, 0.0, 0.0], [1.591017415, 0.9083151309999999, 0.969599703, 3.2589334030000003, 3.223092415, 2.34769243, 1.343920359, 1.4705725859999998, 0.014813803, 0.009677327, 0.008749057, 0.028008451, 0.046048472, 0.037521846, 0.01869757, 0.01790256, 1.0, 0.0, 0.0], [1.898588201, 1.017795959, 1.3778024830000002, 4.069130271000001, 3.044697588, 2.535950761, 1.512933476, 1.661024729, 0.018581214, 0.011645202, 0.01243379, 0.035481237, 0.044583948, 0.039372741, 0.020992887000000002, 0.020239406, 0.0, 1.0, 0.0], [1.951630259, 0.9668951809999999, 1.1501877459999998, 3.6250001610000004, 3.065987717, 2.7704967989999996, 1.414525236, 1.4780174369999999, 0.019477526000000002, 0.012054968000000001, 0.010672721000000001, 0.032024082999999995, 0.045732734000000004, 0.045536818, 0.02013372, 0.018537756000000002, 0.0, 0.0, 1.0], [2.7048712860000004, 1.413281211, 1.750041351, 4.518752372, 2.785182905, 2.630186539, 1.5383735859999998, 1.74886986, 0.029193196, 0.016686943, 0.019093442, 0.043645313, 0.043023802, 0.047415821, 0.024091094, 0.023658342000000002, 0.0, 0.0, 0.0], [1.8677043580000001, 0.888505628, 1.295459218, 3.238280115, 1.97323436, 2.065122701, 1.285088929, 1.3606826369999998, 0.01851557, 0.00976398, 0.012194571000000001, 0.029381077999999998, 0.028531972000000003, 0.035205199, 0.018439825, 0.017141158, 1.0, 0.0, 0.0], [2.079332177, 1.080109763, 1.6703727080000002, 4.221043246, 1.874771215, 2.343519883, 1.4866867019999999, 1.5800223780000002, 0.020727467, 0.012818593999999999, 0.015766995, 0.037833528, 0.027251531000000002, 0.037763271, 0.021878706, 0.02015972, 0.0, 1.0, 0.0], [2.100444532, 1.145940536, 1.728731785, 3.899317545, 2.090201323, 2.548716711, 1.470471937, 1.490279636, 0.021329631, 
0.013454434, 0.016147768, 0.035763899, 0.030578566, 0.04183058, 0.021501488, 0.019250236, 0.0, 0.0, 1.0], [2.950306345, 1.5112896869999999, 2.351607911, 4.731467011, 1.844465518, 2.8097365919999997, 1.8469653999999998, 1.8221058890000001, 0.032024492, 0.020151901, 0.025172227000000002, 0.04863947, 0.028960559, 0.050921875, 0.029782492999999997, 0.025358044, 0.0, 0.0, 0.0], [2.102302496, 1.109463457, 1.705676405, 3.616016015, 1.791049615, 1.646659203, 1.418525725, 1.3888054969999999, 0.021478547, 0.012800691999999999, 0.015618525, 0.033200917, 0.026599995, 0.028638972000000002, 0.020914088, 0.018524923000000002, 1.0, 0.0, 0.0], [2.71240515, 1.4929354719999999, 2.037856136, 4.492740438999999, 2.14143484, 2.370186831, 1.613093361, 1.666366535, 0.027663543, 0.01610518, 0.018982566, 0.04186714, 0.031638377999999995, 0.040042414, 0.024349519, 0.021777669, 0.0, 1.0, 0.0], [2.806773445, 1.463186589, 1.978532512, 4.272488002, 2.15054073, 2.434767109, 1.579509152, 1.616402673, 0.027302765, 0.015191162, 0.017643575, 0.038611802, 0.030754264, 0.041062417999999996, 0.023169526, 0.021261692999999998, 0.0, 0.0, 1.0], [3.996018489, 1.6937410990000001, 2.457377475, 4.9314181569999995, 2.467602943, 2.653987226, 1.941020203, 1.919451741, 0.041623635, 0.020876597, 0.024884077, 0.048516993, 0.037086084, 0.046920558, 0.030469911000000002, 0.026594392999999997, 0.0, 0.0, 0.0], [3.4919227389999996, 1.371221796, 1.822524345, 3.8395642160000003, 1.832072135, 2.107395976, 1.466208683, 1.489599191, 0.031391494, 0.014649446000000002, 0.015725275, 0.033623531, 0.025181705, 0.034567709, 0.020801503, 0.018916638, 1.0, 0.0, 0.0], [4.138746178, 1.590298524, 2.127644349, 4.647601911000001, 2.207977833, 2.408107358, 1.6567405, 1.808700776, 0.03801741, 0.016360663, 0.018774042, 0.041971937, 0.030086376, 0.039289914, 0.024006706000000003, 0.022964617000000003, 0.0, 1.0, 0.0], [4.1102648, 1.413394071, 2.171428253, 4.313140403, 2.25399625, 2.501084059, 1.5605549619999999, 1.7292178530000002, 0.035945307, 
0.014816628, 0.017854382, 0.037720308, 0.028958071000000002, 0.038035442, 0.021590911, 0.021213376000000003, 0.0, 0.0, 1.0], [5.3423692130000004, 2.00075645, 2.8467697339999996, 5.457938977, 2.7274992539999996, 2.882841115, 2.016101307, 2.097992338, 0.05060135599999999, 0.021980033, 0.026668149, 0.051271779, 0.037292837, 0.046894320999999996, 0.029647358, 0.027521142999999998, 0.0, 0.0, 0.0], [3.956125825, 1.477522992, 2.1046672280000003, 4.006806535, 2.087855185, 2.266268641, 1.5526879530000002, 1.69326927, 0.033299168, 0.014485602, 0.016926185, 0.033827733, 0.026473228999999997, 0.034031301, 0.020589085, 0.019857018, 1.0, 0.0, 0.0], [4.763929699, 1.912734514, 2.551680824, 5.157811215, 2.3268884130000003, 2.507616351, 1.7143166930000002, 1.96068761, 0.04122608, 0.018901137, 0.021483049, 0.044278385999999996, 0.029460871, 0.037686629, 0.023066397000000002, 0.023404855, 0.0, 1.0, 0.0], [4.564415962, 1.8742060980000002, 2.417241362, 4.75919687, 2.124588573, 2.49197152, 1.564198999, 1.7978283769999999, 0.04010755, 0.018650783, 0.020088736, 0.041375586, 0.027380357, 0.036819795, 0.021294501, 0.022152966, 0.0, 0.0, 1.0], [5.704490786, 2.007193849, 3.025910542, 5.569184022999999, 2.599997065, 3.031412455, 1.9779736369999998, 2.3238615, 0.056545075, 0.024509824, 0.029932182, 0.054783693, 0.036797021, 0.048806145999999995, 0.029623739, 0.030200502, 0.0, 0.0, 0.0], [4.385355054, 1.7750965840000001, 2.2427225319999997, 4.281760421, 1.995153671, 2.434018311, 1.559940404, 1.7794966030000001, 0.038307006, 0.016830618999999998, 0.01913533, 0.038573799, 0.026450223, 0.037180845, 0.021246899, 0.022488879, 1.0, 0.0, 0.0], [5.198388115, 2.054540531, 2.7105293510000004, 5.345565553, 2.235673312, 2.645565742, 1.703697234, 2.027513482, 0.045321600999999996, 0.01982691, 0.023413355, 0.048358423, 0.029267598, 0.038615263999999996, 0.023635697999999997, 0.024718666, 0.0, 1.0, 0.0], [4.726878735, 1.866208461, 2.416651679, 4.941684604, 2.2316608909999998, 2.6547225759999997, 
1.6940473930000002, 1.8605346459999998, 0.042128195, 0.017842783, 0.021378209, 0.045150000999999995, 0.02925597, 0.038893428, 0.023701022999999998, 0.022907713, 0.0, 0.0, 1.0], [5.801207521, 2.0039953169999998, 2.9539581139999997, 5.991519009, 2.622331047, 3.104103865, 2.027109472, 2.446351267, 0.057742090999999995, 0.024099633, 0.03020391, 0.060468975999999994, 0.037410631, 0.04979212400000001, 0.031075531, 0.032133494, 0.0, 0.0, 0.0], [4.129633203, 1.7628645530000002, 2.106648395, 4.287681536, 1.907463709, 2.2017321819999998, 1.4377587980000002, 1.676535962, 0.037606104, 0.017841951, 0.018735316000000002, 0.040128887, 0.025712252, 0.033209466, 0.020881293999999998, 0.021654631, 1.0, 0.0, 0.0], [4.871555342, 1.848641495, 2.457460101, 5.464322415, 2.217377918, 2.5115484390000002, 1.675566122, 2.155406868, 0.045182755, 0.020011138, 0.022527934, 0.050680244000000006, 0.030097044, 0.037567156000000004, 0.024374032, 0.026638902000000002, 0.0, 1.0, 0.0], [4.807364778999999, 1.781217865, 2.467571628, 5.099299198, 2.137790006, 2.39502626, 1.616507614, 1.957874467, 0.043008935, 0.017975064, 0.021573519, 0.048524019, 0.028405707000000002, 0.035661935, 0.022724618999999998, 0.024828512, 0.0, 0.0, 1.0], [6.137761782, 2.325592972, 3.068304416, 6.60593124, 2.6544308780000003, 2.9649689360000004, 1.9716900819999998, 2.4838355659999998, 0.061144817000000004, 0.027157326000000002, 0.031044887, 0.066139217, 0.037727565, 0.046842577999999996, 0.030086208, 0.032705325, 0.0, 0.0, 0.0], [4.934004151, 2.062502975, 2.535160353, 5.058711702, 2.10704148, 2.3244781580000002, 1.5574535040000002, 1.9222939940000001, 0.040643427, 0.019756041999999998, 0.020668155, 0.04424685, 0.026691507000000003, 0.0335161, 0.021048274, 0.02275978, 1.0, 0.0, 0.0], [5.951752429, 2.563358243, 2.955795347, 6.874571798, 2.376626615, 2.5039389880000003, 1.8146669640000002, 2.391335099, 0.051194878, 0.025470903, 0.025476258999999998, 0.061587229, 0.030522234, 0.035467253, 0.024644739, 0.028081801, 0.0, 1.0, 0.0]]'

json.dumps(t.values.tolist()) 
Out[111]: '[[5.440223458999999, 2.313103303, 2.752490588, 6.106937203999999, 2.262817608, 2.5336477000000004, 1.7491891240000002, 2.210720942, 0.047731258, 0.022545073, 0.024259297000000003, 0.056190795, 0.029918652999999996, 0.035514002, 0.024006705, 0.026801102, 0.0, 0.0, 1.0]]'

0 个答案:

没有答案