XGBoost:获取达到最大AUC时的迭代次数

时间:2018-04-05 21:33:32

标签: r xgboost

我正在研究下面这段代码(取自网上的一个示例),它只是一个使用xgb.cv进行超参数优化的函数。

# Grid search over xgboost hyperparameters with 2-fold cross-validation.
# For each row of `searchGridSubCol`, xgb.cv is run and a numeric vector
# c(test RMSE, train RMSE, subsample, colsample_bytree, max_depth, eta,
#   min_child) is returned; apply() binds these into a matrix with one
# column per grid row. `dtrain` and `ntrees` must already exist.
system.time(
  rmseErrorsHyperparameters <- apply(searchGridSubCol, 1, function(parameterList) {

    # Extract the hyperparameters to test from the current grid row
    currentSubsampleRate <- parameterList[["subsample"]]
    currentColsampleRate <- parameterList[["colsample_bytree"]]
    currentDepth <- parameterList[["max_depth"]]
    currentEta <- parameterList[["eta"]]
    currentMinChild <- parameterList[["min_child"]]

    # NOTE(review): "reg:linear" is a deprecated alias of "reg:squarederror"
    # in newer xgboost releases -- confirm against the installed version.
    xgboostModelCV <- xgb.cv(
      data = dtrain, nrounds = ntrees, nfold = 2, showsd = TRUE,
      metrics = "rmse", verbose = TRUE, "eval_metric" = "rmse",
      "objective" = "reg:linear", "max_depth" = currentDepth,
      "eta" = currentEta, "subsample" = currentSubsampleRate,
      "colsample_bytree" = currentColsampleRate, print_every_n = 10,
      "min_child_weight" = currentMinChild, booster = "gbtree",
      early_stopping_rounds = 10
    )

    xvalidationScores <- as.data.frame(xgboostModelCV$evaluation_log)

    # tail(..., 1) takes the metric from the LAST logged round. When early
    # stopping fires, that round can lie past the best one; the best round
    # is presumably available as xgboostModelCV$best_iteration -- verify.
    rmse <- tail(xvalidationScores$test_rmse_mean, 1)
    trmse <- tail(xvalidationScores$train_rmse_mean, 1)

    # Last expression is the function's value (no `output <- return()` needed)
    c(rmse, trmse, currentSubsampleRate, currentColsampleRate,
      currentDepth, currentEta, currentMinChild)
  })
)

我正在尝试修改上面计算rmse和trmse的那两行代码(即使用tail的两行)。我想把tail替换为max,如下所示:

  # NOTE: RMSE is lower-is-better, so max() picks the WORST round here;
  # max() is the right choice for a higher-is-better metric such as AUC.
  rmse <- max(xvalidationScores[["test_rmse_mean"]])
  trmse <- max(xvalidationScores[["train_rmse_mean"]])

但同时也要保存max出现时的迭代次数。例如,如果我将ntrees设置为100,而rmse在第87次迭代时取得最大值,那么我希望把这个最大值连同它出现的迭代次数(87)一起保存。

后面的代码如下:

# apply() returned one column per grid row; transpose so each
# hyperparameter combination becomes one row of a data frame.
output <- as.data.frame(t(rmseErrorsHyperparameters))
head(output)

# Label the columns in the order the CV function builds its result vector
varnames <- c(
  "TestRMSE", "TrainRMSE", "SubSampRate", "ColSampRate",
  "Depth", "eta", "currentMinChild"
)
colnames(output) <- varnames

其输出如下所示:

TestRMSE    TrainRMSE   SubSampRate ColSampRate Depth   eta currentMinChild
96.07530    96.07417    0.5 0.5 3   0.01    1
96.07458    96.07509    0.6 0.5 3   0.01    1
96.07807    96.07794    0.5 0.6 3   0.01    1
96.07458    96.07557    0.6 0.6 3   0.01    1
96.07829    96.07875    0.5 0.5 4   0.01    1
96.07221    96.07182    0.6 0.5 4   0.01    1

我想添加一个额外的列,用来记录取得max值时的迭代次数。

希望我表达清楚了。

编辑:

我想要的是:在模型运行后,取test AUC这一列,找到它的最大值,并返回该最大值出现时的iter。

下面是我正在处理的输出的dput(注意:我的具体问题使用的是AUC,但如果能知道如何在上面针对RMSE的函数内部获取这些信息,也会很有帮助):

list(structure(list(0.9196126, 0.9033623, 0.5270572, 0.5289016, 
    0.1631758, 0.1662138, iter = c(1, 2, 3, 4, 5, 6, 7, 8, 9, 
    10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 
    25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 
    40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50), train_auc_mean = c(0.8951437, 
    0.8965371, 0.8978626, 0.8988707, 0.9001917, 0.9016094, 0.9024459, 
    0.9034401, 0.9038199, 0.9042312, 0.905216, 0.9054907, 0.9060906, 
    0.9065405, 0.9070835, 0.9078622, 0.9085047, 0.9092435, 0.9096677, 
    0.9103009, 0.9108435, 0.9114527, 0.9117993, 0.9123977, 0.9127984, 
    0.9131131, 0.9134794, 0.9137633, 0.9141034, 0.9144531, 0.9147366, 
    0.9150188, 0.9153703, 0.915693, 0.9159701, 0.9162106, 0.9164605, 
    0.91667, 0.916952, 0.9172335, 0.9174861, 0.9176767, 0.9179821, 
    0.9181687, 0.9184403, 0.9186785, 0.9189445, 0.9191476, 0.9193911, 
    0.9196126), train_auc_std = c(0.000873146499706188, 0.0013455819150254, 
    0.00166724126631134, 0.00229178960858422, 0.00211358610182748, 
    0.0021685266057748, 0.00178882483490829, 0.0014330867001235, 
    0.00133037614601899, 0.00142908416826212, 0.00157062987366438, 
    0.00149423853857791, 0.00158350732239001, 0.00175414020251288, 
    0.00154147242919222, 0.00130750822564675, 0.0011531176913531, 
    0.0012810284345129, 0.00107513022934053, 0.00118107725829082, 
    0.00108282835659572, 0.00095217813982074, 0.000916731045586615, 
    0.000914402652021944, 0.000925037534358363, 0.000933982596136908, 
    0.000968171596397885, 0.000996830080752117, 0.000999102317069, 
    0.000957599858991975, 0.000912228830929314, 0.000911015894471406, 
    0.000830530318576475, 0.00086741697008366, 0.00080877184050021, 
    0.000851959529597412, 0.000820639293489715, 0.000882399229361504, 
    0.000820522028932439, 0.000808661146575655, 0.00077692206173079, 
    0.000769767503869777, 0.000840192412513504, 0.000893997097236334, 
    0.00085496456651761, 0.000847552977747195, 0.000842637674261389, 
    0.000835508252579566, 0.000825488879407905, 0.000849932844426632
    ), train_logloss_mean = c(0.688288, 0.6835267, 0.6788556, 
    0.6742688, 0.6697813, 0.665378, 0.6610398, 0.6567732, 0.652598, 
    0.6484957, 0.644454, 0.6404987, 0.6366067, 0.6327767, 0.6290132, 
    0.6252961, 0.6216444, 0.6180431, 0.6145105, 0.6110207, 0.6075911, 
    0.6042145, 0.6008939, 0.5975965, 0.5943897, 0.5912184, 0.5880742, 
    0.5849882, 0.5819479, 0.5789466, 0.5759881, 0.5730966, 0.5702313, 
    0.5674024, 0.5646118, 0.5618702, 0.5591583, 0.5564953, 0.5538569, 
    0.551244, 0.5486734, 0.5461692, 0.5436563, 0.5412049, 0.5387673, 
    0.5363635, 0.5339838, 0.5316552, 0.5293452, 0.5270572), train_logloss_std = c(2.99566351190152e-05, 
    5.34697108016365e-05, 7.323141406123e-05, 0.000100988910298054, 
    0.000116763050504875, 0.000136931369482075, 0.000156058835328824, 
    0.00015649587866062, 0.000172431435629269, 0.000191563070610692, 
    0.000193559293379142, 0.000203347510654373, 0.000215757757692594, 
    0.000218031672188399, 0.000243628323581981, 0.000235267698553247, 
    0.000255690907102828, 0.000264689044000474, 0.000285014473441321, 
    0.000286014352719424, 0.000297951489435599, 0.000305110226016343, 
    0.000305255122802726, 0.000323379111926704, 0.000337747257559746, 
    0.00035705691416933, 0.000357268750415862, 0.000375639667841919, 
    0.000381227346929675, 0.000411041652483442, 0.000407012395390132, 
    0.000423362303532172, 0.000422746271412242, 0.000423491251426115, 
    0.000445850602814743, 0.000456149931454845, 0.000469240460804438, 
    0.000483815057580709, 0.000489986214114286, 0.000486560582135563, 
    0.000498071721766642, 0.000503954124841367, 0.000513733208988433, 
    0.000517833844033052, 0.000526785354830916, 0.000504555299266918, 
    0.000531380240476864, 0.000504611494165455, 0.000509877200954216, 
    0.000497739650806203), train_error_mean = c(0.2136593, 0.2115807, 
    0.2094728, 0.2058215, 0.2010147, 0.2011363, 0.1976663, 0.1963016, 
    0.1947679, 0.1944829, 0.1950778, 0.1926892, 0.1922478, 0.1925048, 
    0.1926181, 0.1916276, 0.1913063, 0.1898618, 0.1895588, 0.1876229, 
    0.1854967, 0.1847466, 0.1834166, 0.1827868, 0.1812264, 0.1811634, 
    0.1802512, 0.1795696, 0.1783446, 0.1770231, 0.1767397, 0.1762297, 
    0.1752729, 0.1746709, 0.1739863, 0.1735267, 0.1723827, 0.1719105, 
    0.1717445, 0.1709158, 0.1700597, 0.1698614, 0.1684056, 0.1680232, 
    0.1671723, 0.1666052, 0.1652125, 0.1646313, 0.1641831, 0.1631758
    ), train_error_std = c(0.0162186588659482, 0.0166668028250771, 
    0.016164905120662, 0.0154685546593727, 0.0140211525849342, 
    0.0153654141372761, 0.0114450236526621, 0.0104903397771476, 
    0.0101908302061216, 0.0105073437409266, 0.00951226085428692, 
    0.0104100300460661, 0.00956246733641493, 0.0105556220072529, 
    0.0104740322841777, 0.00975581748701749, 0.00972889988693526, 
    0.00905565947681336, 0.00811476540387954, 0.00769535982849394, 
    0.00694964026479031, 0.00679897888215549, 0.00586689270738753, 
    0.0057987687969083, 0.00558790694625423, 0.00548896125327865, 
    0.0058105301272773, 0.00520588054415421, 0.00556398600285821, 
    0.00461605928146561, 0.00440342818835443, 0.00422396041766551, 
    0.00443108845431895, 0.00436686478952625, 0.00382032069465461, 
    0.00399028547976202, 0.00390278308518453, 0.00443162686267732, 
    0.00428972738644375, 0.0045192104133351, 0.00449508178012347, 
    0.00420268264802401, 0.00448433832800408, 0.00430475734043122, 
    0.00393493812022508, 0.00385852199164364, 0.00405450323097614, 
    0.00428838281989813, 0.00386824460059093, 0.00348298627042897
    ), test_auc_mean = c(0.8811229, 0.8821012, 0.8830099, 0.8842433, 
    0.8855852, 0.8866878, 0.8877955, 0.8888536, 0.8891683, 0.8894846, 
    0.8903621, 0.8906483, 0.8912014, 0.8914397, 0.8919142, 0.8926466, 
    0.8933535, 0.8940554, 0.8943622, 0.8950556, 0.8956304, 0.8962909, 
    0.896744, 0.8972643, 0.8975589, 0.8979465, 0.8982284, 0.8985075, 
    0.898786, 0.8991277, 0.8993767, 0.8995949, 0.8997907, 0.9001205, 
    0.9003218, 0.9004286, 0.9007523, 0.9009127, 0.9011547, 0.9013411, 
    0.9015791, 0.9017718, 0.9020105, 0.9021692, 0.9024216, 0.9026098, 
    0.9028157, 0.902953, 0.9031686, 0.9033623), test_auc_std = c(0.00754330327441656, 
    0.00754328296963419, 0.00806933700436532, 0.00790173398754956, 
    0.00667076577014675, 0.00650584392373234, 0.00638239292508034, 
    0.00654034346498086, 0.00636851100414996, 0.00645713921486273, 
    0.00588706627192151, 0.00644870064510048, 0.00604052243104865, 
    0.00604576752862067, 0.00580967446248283, 0.0059154939472606, 
    0.00591113505597135, 0.00602747522931076, 0.00617529822437917, 
    0.00612948071536362, 0.00587437571491952, 0.00618200494095612, 
    0.00603584683370449, 0.00602057940815899, 0.00616550108993821, 
    0.00590366130888615, 0.00599993273628292, 0.0058843139829564, 
    0.00601529726614352, 0.00593407575028349, 0.00598719293241462, 
    0.00597193897239703, 0.00597705621606727, 0.00603783070730093, 
    0.00595028283026266, 0.00600288719200781, 0.00608745137229766, 
    0.00596840690720039, 0.00591811708316459, 0.00598834670756388, 
    0.00600767183608287, 0.00599411786337455, 0.00608440926055143, 
    0.00610958063699259, 0.00609401368885199, 0.00613583449580456, 
    0.00613656653268947, 0.00626081916364727, 0.00621396780165399, 
    0.00625697591573605), test_logloss_mean = c(0.6883233, 0.6836026, 
    0.6789709, 0.6744245, 0.6699811, 0.6656353, 0.6613301, 0.657108, 
    0.652976, 0.6489123, 0.6449008, 0.6409902, 0.6371309, 0.6333362, 
    0.6296182, 0.625937, 0.6223281, 0.6187571, 0.6152755, 0.6118127, 
    0.6084108, 0.6050751, 0.6017758, 0.598529, 0.5953517, 0.5922165, 
    0.5891215, 0.5860721, 0.583079, 0.5801211, 0.5772012, 0.5743363, 
    0.5715108, 0.56872, 0.5659637, 0.5632635, 0.5605788, 0.5579482, 
    0.5553503, 0.5527782, 0.55024, 0.5477804, 0.5453004, 0.5428688, 
    0.5404586, 0.5380859, 0.5357354, 0.5334368, 0.5311633, 0.5289016
    ), test_logloss_std = c(5.68613228217379e-05, 0.000109613138236144, 
    0.000154851186578381, 0.000202748242988687, 0.000252505623714928, 
    0.000297037388133673, 0.000343537028504947, 0.00039417280473763, 
    0.000445152108809884, 0.000490523200283454, 0.000536685159023661, 
    0.0005805109473879, 0.000627699840607671, 0.000673758829247374, 
    0.000710880693211346, 0.000738838006625038, 0.00079860033178686, 
    0.000843823494603547, 0.000901632990750814, 0.000930852625294849, 
    0.000972119622271643, 0.00101327710425968, 0.00104329600787638, 
    0.00108068654106726, 0.00113131675935848, 0.00117079445247857, 
    0.00119185336764085, 0.0012290305488644, 0.00125289776117929, 
    0.00128907559512411, 0.00132504330496146, 0.00132748589822506, 
    0.00135826829455953, 0.00137503134509442, 0.00141204476205794, 
    0.00145400538168406, 0.00147068200505347, 0.00149451054194676, 
    0.00150767503462061, 0.00153264808746057, 0.00155560971970793, 
    0.00159670831401678, 0.00163016405309845, 0.00163873816091979, 
    0.00170482827286317, 0.00170039315748469, 0.00176976649308563, 
    0.0018018847798633, 0.00180684720160447, 0.00181675981903026
    ), test_error_mean = c(0.2154208, 0.2136102, 0.2114981, 0.207626, 
    0.203138, 0.2036158, 0.1998065, 0.1983859, 0.1969524, 0.1964246, 
    0.1972165, 0.1945513, 0.194086, 0.194501, 0.1944506, 0.1935582, 
    0.193269, 0.1920118, 0.191534, 0.1895979, 0.1878002, 0.1868699, 
    0.1862034, 0.1854242, 0.1836765, 0.1835887, 0.1826585, 0.1820548, 
    0.1805965, 0.1794273, 0.1789999, 0.1783083, 0.177818, 0.1771392, 
    0.1765609, 0.1759952, 0.1751652, 0.1749138, 0.1747126, 0.1739585, 
    0.1729149, 0.1729148, 0.1714439, 0.1714063, 0.1704382, 0.1699102, 
    0.1682128, 0.1676346, 0.1671192, 0.1662138), test_error_std = c(0.0159501571014207, 
    0.0163406308311522, 0.0153050457068903, 0.014898956648034, 
    0.0123790320219314, 0.0142987541681082, 0.00940067057448582, 
    0.0088897350292345, 0.00897252415098445, 0.00948333106244869, 
    0.00889585276687978, 0.0091076255418198, 0.00887665868443747, 
    0.00955196525328665, 0.00961350794663399, 0.00919105141754684, 
    0.00905392616492976, 0.00836850337635058, 0.00725348700970739, 
    0.00717273843730541, 0.00623918106485095, 0.00592587428908145, 
    0.00519975147867688, 0.00519286860607813, 0.00448201265616255, 
    0.00465992927092304, 0.00478653503591049, 0.0044566891029114, 
    0.00439628043350302, 0.00417364188808735, 0.00360300248265351, 
    0.00380256084895415, 0.00359530413178045, 0.0035239984335982, 
    0.0033875729497689, 0.00366222912445446, 0.00349675529026618, 
    0.00399751942584358, 0.00390744927542223, 0.00395275595629218, 
    0.00375911754139128, 0.00325926435871614, 0.00359755951861786, 
    0.00353682479209724, 0.0031308523695628, 0.00366587181445227, 
    0.00373223744689414, 0.004797147823447, 0.00457111983653927, 
    0.00401040810392153), 1, 1, 5, 0.01, 0, 1, 0, 0, 1, 1), .Names = c("", 
"", "", "", "", "", "iter", "train_auc_mean", "train_auc_std", 
"train_logloss_mean", "train_logloss_std", "train_error_mean", 
"train_error_std", "test_auc_mean", "test_auc_std", "test_logloss_mean", 
"test_logloss_std", "test_error_mean", "test_error_std", "", 
"", "", "", "", "", "", "", "", "")))

1 个答案:

答案 0 :(得分:0)

如果把您的输出保存到名为res的对象中,再用str()查看,就会发现res是一个列表;并且res只包含一个元素,而这个元素本身也是一个列表。

您可以自己检查:
# Inspect the outer list first, then its single (list) element
str(res)
str(res[[1L]])

您想要的值似乎存储在res[[1]]$test_auc_mean中。因此,如果您只需要取得最大test_auc_mean时的迭代索引,可以键入:

which.max(res[[1L]][["test_auc_mean"]])
# [1] 50

在这种情况下,我认为不需要为最大迭代次数单独添加一列。您可以通过键入以下代码来查找test_auc_mean的最大值:
max(res[[1]]$test_auc_mean)
# [1] 0.9033623   (as printed with the default options(digits = 7))