Type error in sparklyr: Column prediction must be of type org.apache.spark.mllib.linalg.VectorUDT@f71b0bce but was actually DoubleType

Date: 2018-04-25 07:48:16

Tags: r apache-spark machine-learning sparklyr

I am trying to follow this tutorial: https://beta.rstudioconnect.com/content/1518/notebook-classification.html#auc_and_accuracy

I don't understand why I get the error shown below, since I simply pasted the code from that site, and I don't know how to convert the column to the correct type. Does anyone have a solution? :)

My data is partitioned and has the following shape:

> partition
$train
# Source:   table<sparklyr_tmp_100e145972790> [?? x 9]
# Database: spark_connection
   Survived Pclass Sex      Age SibSp Parch  Fare Embarked Family_Sizes
      <dbl> <chr>  <chr>  <dbl> <dbl> <dbl> <dbl> <chr>    <chr>       
 1       0. 1      female    2.    1.    2. 152.  S        1           
 2       0. 1      female   25.    1.    2. 152.  S        1           
 3       0. 1      female   50.    0.    0.  28.7 C        0           
 4       0. 1      male     18.    1.    0. 109.  C        1           
 5       0. 1      male     19.    1.    0.  53.1 S        1           
 6       0. 1      male     19.    3.    2. 263.  S        2           
 7       0. 1      male     22.    0.    0. 136.  C        0           
 8       0. 1      male     24.    0.    0.  79.2 C        0           
 9       0. 1      male     24.    0.    1. 248.  C        1           
10       0. 1      male     27.    0.    2. 212.  C        1           
# ... with more rows

Then I simply apply a model, for example a logistic regression:

# Create table references
train_tbl <- partition$train
test_tbl <- partition$test

# Model survival as a function of several predictors
ml_formula <- formula(Survived ~ Pclass + Sex + Age + SibSp + Parch + Fare + 
Embarked + Family_Sizes)

# Train a logistic regression model
ml_log <- ml_logistic_regression(train_tbl, ml_formula)

# Create a function for scoring
score_test_data <- function(model, data=test_tbl){
  pred <- sdf_predict(model, data)
  select(pred, Survived, prediction)
}

# Calculate the score and AUC metric
ml_score <- score_test_data(ml_log)

Now ml_score is:

> ml_score
# Source:   lazy query [?? x 2]
# Database: spark_connection
   Survived prediction
      <dbl>      <dbl>
 1       0.         1.
 2       0.         0.
 3       0.         0.
 4       0.         0.
 5       0.         0.
 6       0.         0.
 7       0.         0.
 8       0.         0.
 9       0.         0.
10       0.         0.
# ... with more rows

Now I apply the function ml_binary_classification_eval:

ml_binary_classification_eval(ml_score,'Survived','prediction')

Then I get this error:

Error: java.lang.IllegalArgumentException: requirement failed: Column prediction must be of type org.apache.spark.mllib.linalg.VectorUDT@f71b0bce but was actually DoubleType.
at scala.Predef$.require(Predef.scala:233)
at org.apache.spark.ml.util.SchemaUtils$.checkColumnType(SchemaUtils.scala:42)
at org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate(BinaryClassificationEvaluator.scala:82)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at sparklyr.Invoke$.invoke(invoke.scala:102)
at sparklyr.StreamHandler$.handleMethodCall(stream.scala:97)
at sparklyr.StreamHandler$.read(stream.scala:62)
at sparklyr.BackendHandler.channelRead0(handler.scala:52)
at sparklyr.BackendHandler.channelRead0(handler.scala:14)
at io.netty.channel.SimpleChannelInboundHandler.channelRead(SimpleChannelInboundHandler.java:105)
at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:308)
at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:294)
at io.netty.handler.codec.MessageToMessageDecoder.channelRead(MessageToMessageDecoder.java:103)
at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:308)
at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:294)
at io.netty.handler.codec.ByteToMessageDecoder.channelRead(ByteToMessageDecoder.java:244)
at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:308)
at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:294)
at io.netty.channel.DefaultChannelPipeline.fireChannelRead(DefaultChannelPipeline.java:846)
at io.netty.channel.nio.AbstractNioByteChannel$NioByteUnsafe.read(AbstractNioByteChannel.java:131)
at io.netty.channel.nio.NioEventLoop.processSelectedKey(NioEventLoop.java:511)
at io.netty.channel.nio.NioEventLoop.processSelectedKeysOptimized(NioEventLoop.java:468)
at io.netty.channel.nio.NioEventLoop.processSelectedKeys(NioEventLoop.java:382)
at io.netty.channel.nio.NioEventLoop.run(NioEventLoop.java:354)
at io.netty.util.concurrent.SingleThreadEventExecutor$2.run(SingleThreadEventExecutor.java:111)
at io.netty.util.concurrent.DefaultThreadFactory$DefaultRunnableDecorator.run(DefaultThreadFactory.java:137)
at java.lang.Thread.run(Thread.java:748)

1 Answer:

Answer 0 (score: 1)

In the current version you need to pass the name of the raw prediction column to ml_binary_classification_evaluator(); by default this is "rawPrediction". The ?ml_evaluator documentation was incorrect on this point and has since been updated.
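
A minimal sketch of the fix, assuming a recent sparklyr release where the scored table carries a rawPrediction vector column and the evaluator exposes the label_col / raw_prediction_col / metric_name arguments (the helper name score_test_data_raw is mine; adapt the argument names if your version differs):

# Score the test data, keeping the raw prediction column
# ("rawPrediction" is the default column name produced by Spark ML;
#  ml_predict() is the newer equivalent of sdf_predict() in later releases)
score_test_data_raw <- function(model, data = test_tbl){
  pred <- sdf_predict(model, data)
  select(pred, Survived, rawPrediction)
}

ml_score_raw <- score_test_data_raw(ml_log)

# Evaluate AUC against the raw scores rather than the
# thresholded 'prediction' column
ml_binary_classification_evaluator(
  ml_score_raw,
  label_col          = "Survived",
  raw_prediction_col = "rawPrediction",
  metric_name        = "areaUnderROC"
)

The reason the original call fails is that areaUnderROC is computed from the continuous classifier scores, which Spark stores as a vector in rawPrediction; the prediction column only holds the thresholded 0/1 label as a DoubleType, which is exactly what the exception complains about.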