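I'm trying to fit a logistic regression model using sparklyr::ml_logistic_regression. My training dataset has 42,457 rows and 785 columns. The response is the label column, a 0/1 integer, and all remaining columns are 0/1 integer features. My source data is in an R data frame (df), and I can fit the model successfully in base R with glm(label ~ ., data = df, family = binomial).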
Unfortunately, I can't fit the same model with ml_logistic_regression. The code is below; sc is an existing Spark connection.
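library(sparklyr)
library(tidyverse)
# Copy the local data frame to Spark and get a reference to the table
copy_to(sc, df, "spark_train", overwrite = TRUE)
train_tbl <- tbl(sc, "spark_train")
# Fit a logistic regression of label on all other columns
fit <- ml_logistic_regression(train_tbl, label ~ .)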
Here is the stack trace:
d> fit <- ml_logistic_regression(train_tbl, label ~ .)
* No rows dropped by 'na.omit' call
Error: java.lang.ArrayIndexOutOfBoundsException: 1
at org.apache.spark.ml.classification.LogisticRegression.train(LogisticRegression.scala:343)
at org.apache.spark.ml.classification.LogisticRegression.train(LogisticRegression.scala:159)
at org.apache.spark.ml.Predictor.fit(Predictor.scala:90)
at org.apache.spark.ml.Predictor.fit(Predictor.scala:71)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(Unknown Source)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(Unknown Source)
at java.lang.reflect.Method.invoke(Unknown Source)
at sparklyr.Invoke$.invoke(invoke.scala:94)
at sparklyr.StreamHandler$.handleMethodCall(stream.scala:89)
at sparklyr.StreamHandler$.read(stream.scala:55)
at sparklyr.BackendHandler.channelRead0(handler.scala:49)
at sparklyr.BackendHandler.channelRead0(handler.scala:14)
at io.netty.channel.SimpleChannelInboundHandler.channelRead(SimpleChannelInboundHandler.java:105)
at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:308)
at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:294)
at io.netty.handler.codec.MessageToMessageDecoder.channelRead(MessageToMessageDecoder.java:103)
at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:308)
at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:294)
at io.netty.handler.codec.ByteToMessageDecoder.channelRead(ByteToMessageDecoder.java:244)
at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:308)
at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:294)
at io.netty.channel.DefaultChannelPipeline.fireChannelRead(DefaultChannelPipeline.java:846)
at io.netty.channel.nio.AbstractNioByteChannel$NioByteUnsafe.read(AbstractNioByteChannel.java:131)
at io.netty.channel.nio.NioEventLoop.processSelectedKey(NioEventLoop.java:511)
at io.netty.channel.nio.NioEventLoop.processSelectedKeysOptimized(NioEventLoop.java:468)
at io.netty.channel.nio.NioEventLoop.processSelectedKeys(NioEventLoop.java:382)
at io.netty.channel.nio.NioEventLoop.run(NioEventLoop.java:354)
at io.netty.util.concurrent.SingleThreadEventExecutor$2.run(SingleThreadEventExecutor.java:111)
at io.netty.util.concurrent.DefaultThreadFactory$DefaultRunnableDecorator.run(DefaultThreadFactory.java:137)
at java.lang.Thread.run(Unknown Source)
Here is my sessionInfo():
R version 3.3.2 (2016-10-31)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows >= 8 x64 (build 9200)
locale:
[1] LC_COLLATE=English_United Kingdom.1252 LC_CTYPE=English_United Kingdom.1252
[3] LC_MONETARY=English_United Kingdom.1252 LC_NUMERIC=C
[5] LC_TIME=English_United Kingdom.1252
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] dplyr_0.7.1 purrr_0.2.2.2 readr_1.0.0 tidyr_0.6.3
[5] tibble_1.3.3 ggplot2_2.2.1 tidyverse_1.1.1 sparklyr_0.5.6
[9] robomarker_0.1.0 devtools_1.12.0
loaded via a namespace (and not attached):
[1] h2o_3.10.5.2 reshape2_1.4.2 haven_1.0.0 lattice_0.20-34
[5] colorspace_1.3-2 htmltools_0.3.5 yaml_2.1.14 base64enc_0.1-3
[9] rlang_0.1.1 foreign_0.8-67 glue_1.1.1 withr_1.0.2
[13] DBI_0.7 rappdirs_0.3.1 dbplyr_1.0.0 modelr_0.1.0
[17] readxl_1.0.0 bindrcpp_0.2 bindr_0.1 plyr_1.8.4
[21] stringr_1.2.0 munsell_0.4.3 commonmark_1.1 gtable_0.2.0
[25] cellranger_1.1.0 rvest_0.3.2 psych_1.7.3.21 memoise_1.0.0
[29] forcats_0.2.0 httpuv_1.3.3 parallel_3.3.2 broom_0.4.2
[33] Rcpp_0.12.10 xtable_1.8-2 backports_1.0.5 scales_0.4.1
[37] jsonlite_1.2 config_0.2 mime_0.5 mnormt_1.5-5
[41] hms_0.3 digest_0.6.12 stringi_1.1.2 shiny_1.0.3
[45] grid_3.3.2 rprojroot_1.2 bitops_1.0-6 tools_3.3.2
[49] magrittr_1.5 RCurl_1.95-4.8 lazyeval_0.2.0 pkgconfig_2.0.1
[53] xml2_1.1.1 lubridate_1.6.0 assertthat_0.1 roxygen2_6.0.1
[57] httr_1.2.1 rstudioapi_0.6 R6_2.2.0 rsparkling_0.2.0
[61] nlme_3.1-128
Any idea why this is happening?
Answer (score: 1):
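This error can be caused by the training dataset containing only one label value. Check that your label column actually has more than one distinct value; depending on your Spark version, you may also be limited to exactly two labels (i.e. 0 and 1, for binomial regression).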
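One way to run that check is on the local data frame before copying it to Spark. The following is a minimal sketch that assumes the response column is named label, as in the question; check_binary_label is a hypothetical helper written for illustration, not part of sparklyr, and the two small data frames are made up.

```r
# Hypothetical helper (not part of sparklyr): print the label counts and
# report whether the response column contains exactly the values 0 and 1.
# Assumes the response column is named "label" by default.
check_binary_label <- function(df, label_col = "label") {
  # Count how many rows fall into each distinct label value (including NA)
  counts <- table(df[[label_col]], useNA = "ifany")
  print(counts)
  # Binomial logistic regression needs both classes present: 0 and 1.
  # sort() drops NA values, mirroring the na.omit step sparklyr performs.
  values <- sort(unique(df[[label_col]]))
  identical(as.numeric(values), c(0, 1))
}

# Usage with small made-up data frames (illustrative only)
ok  <- data.frame(label = c(0L, 1L, 1L, 0L), x1 = c(1L, 0L, 1L, 1L))
bad <- data.frame(label = c(1L, 1L, 1L), x1 = c(0L, 1L, 0L))
check_binary_label(ok)   # TRUE
check_binary_label(bad)  # FALSE: only one label value
```

The same counts can be obtained from the Spark side with dplyr::count(train_tbl, label), which is useful for confirming that the table Spark actually trains on still contains both classes after copy_to().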