在R中将h2o.grid与fold_column一起使用时报错

时间:2018-03-08 13:17:00

标签: r h2o

对任何算法/方法使用h2o.grid时，只要使用fold_column参数就会报错。有人知道这是为什么吗？只使用nfolds时，h2o.grid可以正常工作；不使用网格函数、直接训练单个模型时，使用fold_column也没有问题；但是当我把h2o.grid和fold_column组合使用时就会出错，尽管按照文档这应该是可行的（http://docs.h2o.ai/h2o/latest-stable/h2o-docs/grid-search.html）。

library(h2o)

# Create random data ----
# Fixed seed so the reproducible example gives the same data on every run.
set.seed(2018)

N <- 1000

y <- rnorm(N)
x1 <- y + rnorm(N)         # y plus noise
x2 <- -y + 0.5 * rnorm(N)  # negated y plus smaller noise
x3 <- rnorm(N)             # independent of y

# Split is the fold id (1..5) later passed as fold_column. rep_len() yields
# exactly N values for any N, whereas rep(1:5, N / 5) only has length N when
# N is divisible by 5. A base data.frame is used so the script needs no
# package beyond h2o (as.h2o() accepts a data.frame).
Data <- data.frame(
  y = y,
  x1 = x1,
  x2 = x2,
  x3 = x3,
  Split = rep_len(1:5, N),
  Weight = runif(N)
)

# Sort rows by fold id and reset the row names left behind by the reorder.
Data <- Data[order(Data$Split), ]
rownames(Data) <- NULL


# Start H2O and upload the data ----
h2o.init(
  nthreads = -1,        # -1: use all available threads
  max_mem_size = "4G"   # memory size for the H2O cloud
)
h2o.removeAll()  # clean slate in case the cluster was already running

# Upload directly under the key "bla.hex". The previous
# h2o.assign(as.h2o(Data), "bla.hex") uploaded under an auto-generated key and
# then made a copy, leaving an extra unused frame in the cluster.
bla <- as.h2o(Data, destination_frame = "bla.hex")


# Simple hyper-parameter grid ----
Grid <- list(ntrees = 1:5)

# Not working: grid search with fold_column ----
# NOTE(review): on h2o 3.8.2.6 this call fails server-side with a
# NullPointerException in hex.ModelBuilder.nFoldWork, reached via
# hex.grid.GridSearch.startBuildModel. The same grid with nfolds works, and a
# single model fitted with fold_column works. That suggests a defect in the
# grid-search cross-validation path of that release rather than in these
# arguments. Re-test on a current h2o release first — TODO confirm.
GridResult <- h2o.grid(
  "gbm",
  x = c("x1", "x2", "x3"),
  y = "y",
  training_frame = bla,
  hyper_params = Grid,
  keep_cross_validation_predictions = TRUE,
  weights_column = "Weight",
  fold_column = "Split"  # per-row fold ids 1..5; not listed in x
)

# Working: the same grid with nfolds instead of fold_column ----
# H2O assigns the 5 folds itself here, so the folds differ from Split.
GridResult <- h2o.grid(
  "gbm",
  x = c("x1", "x2", "x3"),
  y = "y",
  training_frame = bla,
  hyper_params = Grid,
  keep_cross_validation_predictions = TRUE,
  weights_column = "Weight",
  nfolds = 5
)

我的环境信息：R 版本 3.3.1（2016-06-21）；平台：x86_64-w64-mingw32/x64（64 位）；运行于：Windows 7 x64（内部版本 7601）Service Pack 1；h2o 版本：h2o_3.8.2.6。

错误:

[2018-03-08 14:14:38] failure_details: NA 
[2018-03-08 14:14:38] failure_stack_traces: java.lang.NullPointerException
    at hex.ModelBuilder.nFoldWork(ModelBuilder.java:209)
    at hex.ModelBuilder.computeCrossValidation(ModelBuilder.java:224)
    at hex.ModelBuilder.trainModelNested(ModelBuilder.java:186)
    at hex.grid.GridSearch.startBuildModel(GridSearch.java:329)
    at hex.grid.GridSearch.buildModel(GridSearch.java:311)
    at hex.grid.GridSearch.gridSearch(GridSearch.java:215)
    at hex.grid.GridSearch.access$000(GridSearch.java:69)
    at hex.grid.GridSearch$1.compute2(GridSearch.java:136)
    at water.H2O$H2OCountedCompleter.compute(H2O.java:1194)
    at jsr166y.CountedCompleter.exec(CountedCompleter.java:468)
    at jsr166y.ForkJoinTask.doExec(ForkJoinTask.java:263)
    at jsr166y.ForkJoinPool$WorkQueue.runTask(ForkJoinPool.java:974)
    at jsr166y.ForkJoinPool.runWorker(ForkJoinPool.java:1477)
    at jsr166y.ForkJoinWorkerThread.run(ForkJoinWorkerThread.java:104)

0 个答案:

没有答案