我正在使用以下代码(取自 here),它只是一个基于 `xgb.cv` 的参数优化函数:
# Grid search over xgboost hyperparameters using 2-fold cross-validation.
# For each row of searchGridSubCol, run xgb.cv and return the numeric vector
#   c(TestRMSE, TrainRMSE, subsample, colsample_bytree, max_depth, eta,
#     min_child)
# so that apply() yields one column per configuration.
# NOTE(review): relies on dtrain, ntrees and searchGridSubCol being defined
# elsewhere, and on the xgboost package being attached.
system.time(
  rmseErrorsHyperparameters <- apply(searchGridSubCol, 1, function(parameterList) {
    # Extract the parameters to test from this grid row.
    currentSubsampleRate <- parameterList[["subsample"]]
    currentColsampleRate <- parameterList[["colsample_bytree"]]
    currentDepth <- parameterList[["max_depth"]]
    currentEta <- parameterList[["eta"]]
    currentMinChild <- parameterList[["min_child"]]
    # NOTE(review): newer xgboost releases deprecate "reg:linear" in favour of
    # "reg:squarederror"; kept here for compatibility with older versions.
    xgboostModelCV <- xgb.cv(
      data = dtrain, nrounds = ntrees, nfold = 2, showsd = TRUE,
      metrics = "rmse", verbose = TRUE, "eval_metric" = "rmse",
      "objective" = "reg:linear", "max.depth" = currentDepth,
      "eta" = currentEta, "subsample" = currentSubsampleRate,
      "colsample_bytree" = currentColsampleRate, print_every_n = 10,
      "min_child_weight" = currentMinChild, booster = "gbtree",
      early_stopping_rounds = 10
    )
    xvalidationScores <- as.data.frame(xgboostModelCV$evaluation_log)
    # Early stopping lets training continue for up to 10 rounds after the best
    # one, so the last logged row (what tail() returned) is generally not the
    # best score. Select the round with the lowest test RMSE (lower is better;
    # for a higher-is-better metric such as AUC use which.max instead) and
    # report train and test RMSE from that same round.
    bestIter <- which.min(xvalidationScores$test_rmse_mean)
    rmse <- xvalidationScores$test_rmse_mean[bestIter]
    trmse <- xvalidationScores$train_rmse_mean[bestIter]
    # The round number itself is xvalidationScores$iter[bestIter]. To keep it,
    # append it to the vector below and add a matching column label to the
    # names assigned afterwards.
    c(rmse, trmse, currentSubsampleRate, currentColsampleRate,
      currentDepth, currentEta, currentMinChild)
  })
)
我正在尝试修改 `****` 标记之间的代码。我想要做的是将 `tail` 替换为 `max`,如下所示:
# Pick ONE boosting round and read both scores from that same round, so the
# train/test pair stays consistent and the round is kept in bestIter.
# which.max suits a higher-is-better metric such as AUC; for RMSE (lower is
# better) use which.min instead. bestIter is the row index in the evaluation
# log, which equals the round number when the log's iter column runs 1..n.
bestIter <- which.max(xvalidationScores$test_rmse_mean)
rmse <- xvalidationScores$test_rmse_mean[bestIter]
trmse <- xvalidationScores$train_rmse_mean[bestIter]
同时还要保存取得 `max` 值时的迭代次数。例如,如果我将 `ntrees` 设置为 100,而 `rmse` 在第 87 次迭代时取得最大值,那么我希望把这个值连同它出现时的迭代次数(87)一起保存下来。
后续代码如下:
# Column labels, in the same order as the vector returned per configuration.
varnames <- c("TestRMSE", "TrainRMSE", "SubSampRate", "ColSampRate", "Depth",
              "eta", "currentMinChild")
# apply() yields one column per configuration; t() turns each into a row.
output <- as.data.frame(t(rmseErrorsHyperparameters))
head(output)  # printed before the labels below are applied
colnames(output) <- varnames
输出如下所示:
TestRMSE TrainRMSE SubSampRate ColSampRate Depth eta currentMinChild
96.07530 96.07417 0.5 0.5 3 0.01 1
96.07458 96.07509 0.6 0.5 3 0.01 1
96.07807 96.07794 0.5 0.6 3 0.01 1
96.07458 96.07557 0.6 0.6 3 0.01 1
96.07829 96.07875 0.5 0.5 4 0.01 1
96.07221 96.07182 0.6 0.5 4 0.01 1
我想要的是再添加一列,记录取得 `max` 值时的迭代次数。
希望我表达清楚了。
编辑:
我想在模型运行之后,取 testAUC 列(即 `test_auc_mean`),找到其 `max` 值,并返回该 `max` 出现时的 `iter`。
下面粘贴了我正在处理的输出的 `dput`
(注意:在我的具体问题中我使用的是 AUC,但如果能知道如何在函数内部针对 RMSE 获取这一信息,也会很有帮助):
list(structure(list(0.9196126, 0.9033623, 0.5270572, 0.5289016,
0.1631758, 0.1662138, iter = c(1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50), train_auc_mean = c(0.8951437,
0.8965371, 0.8978626, 0.8988707, 0.9001917, 0.9016094, 0.9024459,
0.9034401, 0.9038199, 0.9042312, 0.905216, 0.9054907, 0.9060906,
0.9065405, 0.9070835, 0.9078622, 0.9085047, 0.9092435, 0.9096677,
0.9103009, 0.9108435, 0.9114527, 0.9117993, 0.9123977, 0.9127984,
0.9131131, 0.9134794, 0.9137633, 0.9141034, 0.9144531, 0.9147366,
0.9150188, 0.9153703, 0.915693, 0.9159701, 0.9162106, 0.9164605,
0.91667, 0.916952, 0.9172335, 0.9174861, 0.9176767, 0.9179821,
0.9181687, 0.9184403, 0.9186785, 0.9189445, 0.9191476, 0.9193911,
0.9196126), train_auc_std = c(0.000873146499706188, 0.0013455819150254,
0.00166724126631134, 0.00229178960858422, 0.00211358610182748,
0.0021685266057748, 0.00178882483490829, 0.0014330867001235,
0.00133037614601899, 0.00142908416826212, 0.00157062987366438,
0.00149423853857791, 0.00158350732239001, 0.00175414020251288,
0.00154147242919222, 0.00130750822564675, 0.0011531176913531,
0.0012810284345129, 0.00107513022934053, 0.00118107725829082,
0.00108282835659572, 0.00095217813982074, 0.000916731045586615,
0.000914402652021944, 0.000925037534358363, 0.000933982596136908,
0.000968171596397885, 0.000996830080752117, 0.000999102317069,
0.000957599858991975, 0.000912228830929314, 0.000911015894471406,
0.000830530318576475, 0.00086741697008366, 0.00080877184050021,
0.000851959529597412, 0.000820639293489715, 0.000882399229361504,
0.000820522028932439, 0.000808661146575655, 0.00077692206173079,
0.000769767503869777, 0.000840192412513504, 0.000893997097236334,
0.00085496456651761, 0.000847552977747195, 0.000842637674261389,
0.000835508252579566, 0.000825488879407905, 0.000849932844426632
), train_logloss_mean = c(0.688288, 0.6835267, 0.6788556,
0.6742688, 0.6697813, 0.665378, 0.6610398, 0.6567732, 0.652598,
0.6484957, 0.644454, 0.6404987, 0.6366067, 0.6327767, 0.6290132,
0.6252961, 0.6216444, 0.6180431, 0.6145105, 0.6110207, 0.6075911,
0.6042145, 0.6008939, 0.5975965, 0.5943897, 0.5912184, 0.5880742,
0.5849882, 0.5819479, 0.5789466, 0.5759881, 0.5730966, 0.5702313,
0.5674024, 0.5646118, 0.5618702, 0.5591583, 0.5564953, 0.5538569,
0.551244, 0.5486734, 0.5461692, 0.5436563, 0.5412049, 0.5387673,
0.5363635, 0.5339838, 0.5316552, 0.5293452, 0.5270572), train_logloss_std = c(2.99566351190152e-05,
5.34697108016365e-05, 7.323141406123e-05, 0.000100988910298054,
0.000116763050504875, 0.000136931369482075, 0.000156058835328824,
0.00015649587866062, 0.000172431435629269, 0.000191563070610692,
0.000193559293379142, 0.000203347510654373, 0.000215757757692594,
0.000218031672188399, 0.000243628323581981, 0.000235267698553247,
0.000255690907102828, 0.000264689044000474, 0.000285014473441321,
0.000286014352719424, 0.000297951489435599, 0.000305110226016343,
0.000305255122802726, 0.000323379111926704, 0.000337747257559746,
0.00035705691416933, 0.000357268750415862, 0.000375639667841919,
0.000381227346929675, 0.000411041652483442, 0.000407012395390132,
0.000423362303532172, 0.000422746271412242, 0.000423491251426115,
0.000445850602814743, 0.000456149931454845, 0.000469240460804438,
0.000483815057580709, 0.000489986214114286, 0.000486560582135563,
0.000498071721766642, 0.000503954124841367, 0.000513733208988433,
0.000517833844033052, 0.000526785354830916, 0.000504555299266918,
0.000531380240476864, 0.000504611494165455, 0.000509877200954216,
0.000497739650806203), train_error_mean = c(0.2136593, 0.2115807,
0.2094728, 0.2058215, 0.2010147, 0.2011363, 0.1976663, 0.1963016,
0.1947679, 0.1944829, 0.1950778, 0.1926892, 0.1922478, 0.1925048,
0.1926181, 0.1916276, 0.1913063, 0.1898618, 0.1895588, 0.1876229,
0.1854967, 0.1847466, 0.1834166, 0.1827868, 0.1812264, 0.1811634,
0.1802512, 0.1795696, 0.1783446, 0.1770231, 0.1767397, 0.1762297,
0.1752729, 0.1746709, 0.1739863, 0.1735267, 0.1723827, 0.1719105,
0.1717445, 0.1709158, 0.1700597, 0.1698614, 0.1684056, 0.1680232,
0.1671723, 0.1666052, 0.1652125, 0.1646313, 0.1641831, 0.1631758
), train_error_std = c(0.0162186588659482, 0.0166668028250771,
0.016164905120662, 0.0154685546593727, 0.0140211525849342,
0.0153654141372761, 0.0114450236526621, 0.0104903397771476,
0.0101908302061216, 0.0105073437409266, 0.00951226085428692,
0.0104100300460661, 0.00956246733641493, 0.0105556220072529,
0.0104740322841777, 0.00975581748701749, 0.00972889988693526,
0.00905565947681336, 0.00811476540387954, 0.00769535982849394,
0.00694964026479031, 0.00679897888215549, 0.00586689270738753,
0.0057987687969083, 0.00558790694625423, 0.00548896125327865,
0.0058105301272773, 0.00520588054415421, 0.00556398600285821,
0.00461605928146561, 0.00440342818835443, 0.00422396041766551,
0.00443108845431895, 0.00436686478952625, 0.00382032069465461,
0.00399028547976202, 0.00390278308518453, 0.00443162686267732,
0.00428972738644375, 0.0045192104133351, 0.00449508178012347,
0.00420268264802401, 0.00448433832800408, 0.00430475734043122,
0.00393493812022508, 0.00385852199164364, 0.00405450323097614,
0.00428838281989813, 0.00386824460059093, 0.00348298627042897
), test_auc_mean = c(0.8811229, 0.8821012, 0.8830099, 0.8842433,
0.8855852, 0.8866878, 0.8877955, 0.8888536, 0.8891683, 0.8894846,
0.8903621, 0.8906483, 0.8912014, 0.8914397, 0.8919142, 0.8926466,
0.8933535, 0.8940554, 0.8943622, 0.8950556, 0.8956304, 0.8962909,
0.896744, 0.8972643, 0.8975589, 0.8979465, 0.8982284, 0.8985075,
0.898786, 0.8991277, 0.8993767, 0.8995949, 0.8997907, 0.9001205,
0.9003218, 0.9004286, 0.9007523, 0.9009127, 0.9011547, 0.9013411,
0.9015791, 0.9017718, 0.9020105, 0.9021692, 0.9024216, 0.9026098,
0.9028157, 0.902953, 0.9031686, 0.9033623), test_auc_std = c(0.00754330327441656,
0.00754328296963419, 0.00806933700436532, 0.00790173398754956,
0.00667076577014675, 0.00650584392373234, 0.00638239292508034,
0.00654034346498086, 0.00636851100414996, 0.00645713921486273,
0.00588706627192151, 0.00644870064510048, 0.00604052243104865,
0.00604576752862067, 0.00580967446248283, 0.0059154939472606,
0.00591113505597135, 0.00602747522931076, 0.00617529822437917,
0.00612948071536362, 0.00587437571491952, 0.00618200494095612,
0.00603584683370449, 0.00602057940815899, 0.00616550108993821,
0.00590366130888615, 0.00599993273628292, 0.0058843139829564,
0.00601529726614352, 0.00593407575028349, 0.00598719293241462,
0.00597193897239703, 0.00597705621606727, 0.00603783070730093,
0.00595028283026266, 0.00600288719200781, 0.00608745137229766,
0.00596840690720039, 0.00591811708316459, 0.00598834670756388,
0.00600767183608287, 0.00599411786337455, 0.00608440926055143,
0.00610958063699259, 0.00609401368885199, 0.00613583449580456,
0.00613656653268947, 0.00626081916364727, 0.00621396780165399,
0.00625697591573605), test_logloss_mean = c(0.6883233, 0.6836026,
0.6789709, 0.6744245, 0.6699811, 0.6656353, 0.6613301, 0.657108,
0.652976, 0.6489123, 0.6449008, 0.6409902, 0.6371309, 0.6333362,
0.6296182, 0.625937, 0.6223281, 0.6187571, 0.6152755, 0.6118127,
0.6084108, 0.6050751, 0.6017758, 0.598529, 0.5953517, 0.5922165,
0.5891215, 0.5860721, 0.583079, 0.5801211, 0.5772012, 0.5743363,
0.5715108, 0.56872, 0.5659637, 0.5632635, 0.5605788, 0.5579482,
0.5553503, 0.5527782, 0.55024, 0.5477804, 0.5453004, 0.5428688,
0.5404586, 0.5380859, 0.5357354, 0.5334368, 0.5311633, 0.5289016
), test_logloss_std = c(5.68613228217379e-05, 0.000109613138236144,
0.000154851186578381, 0.000202748242988687, 0.000252505623714928,
0.000297037388133673, 0.000343537028504947, 0.00039417280473763,
0.000445152108809884, 0.000490523200283454, 0.000536685159023661,
0.0005805109473879, 0.000627699840607671, 0.000673758829247374,
0.000710880693211346, 0.000738838006625038, 0.00079860033178686,
0.000843823494603547, 0.000901632990750814, 0.000930852625294849,
0.000972119622271643, 0.00101327710425968, 0.00104329600787638,
0.00108068654106726, 0.00113131675935848, 0.00117079445247857,
0.00119185336764085, 0.0012290305488644, 0.00125289776117929,
0.00128907559512411, 0.00132504330496146, 0.00132748589822506,
0.00135826829455953, 0.00137503134509442, 0.00141204476205794,
0.00145400538168406, 0.00147068200505347, 0.00149451054194676,
0.00150767503462061, 0.00153264808746057, 0.00155560971970793,
0.00159670831401678, 0.00163016405309845, 0.00163873816091979,
0.00170482827286317, 0.00170039315748469, 0.00176976649308563,
0.0018018847798633, 0.00180684720160447, 0.00181675981903026
), test_error_mean = c(0.2154208, 0.2136102, 0.2114981, 0.207626,
0.203138, 0.2036158, 0.1998065, 0.1983859, 0.1969524, 0.1964246,
0.1972165, 0.1945513, 0.194086, 0.194501, 0.1944506, 0.1935582,
0.193269, 0.1920118, 0.191534, 0.1895979, 0.1878002, 0.1868699,
0.1862034, 0.1854242, 0.1836765, 0.1835887, 0.1826585, 0.1820548,
0.1805965, 0.1794273, 0.1789999, 0.1783083, 0.177818, 0.1771392,
0.1765609, 0.1759952, 0.1751652, 0.1749138, 0.1747126, 0.1739585,
0.1729149, 0.1729148, 0.1714439, 0.1714063, 0.1704382, 0.1699102,
0.1682128, 0.1676346, 0.1671192, 0.1662138), test_error_std = c(0.0159501571014207,
0.0163406308311522, 0.0153050457068903, 0.014898956648034,
0.0123790320219314, 0.0142987541681082, 0.00940067057448582,
0.0088897350292345, 0.00897252415098445, 0.00948333106244869,
0.00889585276687978, 0.0091076255418198, 0.00887665868443747,
0.00955196525328665, 0.00961350794663399, 0.00919105141754684,
0.00905392616492976, 0.00836850337635058, 0.00725348700970739,
0.00717273843730541, 0.00623918106485095, 0.00592587428908145,
0.00519975147867688, 0.00519286860607813, 0.00448201265616255,
0.00465992927092304, 0.00478653503591049, 0.0044566891029114,
0.00439628043350302, 0.00417364188808735, 0.00360300248265351,
0.00380256084895415, 0.00359530413178045, 0.0035239984335982,
0.0033875729497689, 0.00366222912445446, 0.00349675529026618,
0.00399751942584358, 0.00390744927542223, 0.00395275595629218,
0.00375911754139128, 0.00325926435871614, 0.00359755951861786,
0.00353682479209724, 0.0031308523695628, 0.00366587181445227,
0.00373223744689414, 0.004797147823447, 0.00457111983653927,
0.00401040810392153), 1, 1, 5, 0.01, 0, 1, 0, 0, 1, 1), .Names = c("",
"", "", "", "", "", "iter", "train_auc_mean", "train_auc_std",
"train_logloss_mean", "train_logloss_std", "train_error_mean",
"train_error_std", "test_auc_mean", "test_auc_std", "test_logloss_mean",
"test_logloss_std", "test_error_mean", "test_error_std", "",
"", "", "", "", "", "", "", "", "")))
答案 0(得分:0)
如果将输出保存到名为 `res` 的对象中,然后用 `str()` 查看,就可以看到 `res` 是一个列表;而且 `res` 只包含一个元素,该元素本身也是一个列表。
# Show the structure of the top-level list...
str(res)
# ...and of its single element, which holds the evaluation-log columns.
str(res[[1L]])
您想要的值似乎存储在 `res[[1]]$test_auc_mean` 中。因此,如果您只需要取得最大 `test_auc_mean` 的那次迭代的索引,可以键入:
# 1-based position of the round with the highest mean test AUC.
which.max(res[[1L]][["test_auc_mean"]])
# [1] 50
在这种情况下,我认为不需要为最大值所在的迭代单独添加一列。您可以通过键入以下代码来查找 `test_auc_mean` 的最大值:
# The highest mean test AUC itself (shown with R's default 7 significant digits).
max(res[[1L]][["test_auc_mean"]])
# [1] 0.9033623