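I am trying to follow the tutorial on this site: https://beta.rstudioconnect.com/content/1518/notebook-classification.html#auc_and_accuracy
I get an error when computing the AUC and I don't know why, since I just pasted the code from the site. I don't know how to convert the column to the correct type. Does anyone have a solution? :)
My data is partitioned and looks like this: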
> partition
$train
# Source: table<sparklyr_tmp_100e145972790> [?? x 9]
# Database: spark_connection
Survived Pclass Sex Age SibSp Parch Fare Embarked Family_Sizes
<dbl> <chr> <chr> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
1 0. 1 female 2. 1. 2. 152. S 1
2 0. 1 female 25. 1. 2. 152. S 1
3 0. 1 female 50. 0. 0. 28.7 C 0
4 0. 1 male 18. 1. 0. 109. C 1
5 0. 1 male 19. 1. 0. 53.1 S 1
6 0. 1 male 19. 3. 2. 263. S 2
7 0. 1 male 22. 0. 0. 136. C 0
8 0. 1 male 24. 0. 0. 79.2 C 0
9 0. 1 male 24. 0. 1. 248. C 1
10 0. 1 male 27. 0. 2. 212. C 1
# ... with more rows
Then I fit a single model, for example logistic regression.
# Create table references
train_tbl <- partition$train
test_tbl <- partition$test
# Model survival as a function of several predictors
ml_formula <- formula(Survived ~ Pclass + Sex + Age + SibSp + Parch + Fare +
                        Embarked + Family_Sizes)
# Train a logistic regression model
ml_log <- ml_logistic_regression(train_tbl, ml_formula)
# Create a function for scoring
score_test_data <- function(model, data = test_tbl) {
  pred <- sdf_predict(model, data)
  select(pred, Survived, prediction)
}
# Calculate the score and AUC metric
ml_score <- score_test_data(ml_log)
Now ml_score is:
> ml_score
# Source: lazy query [?? x 2]
# Database: spark_connection
Survived prediction
<dbl> <dbl>
1 0. 1.
2 0. 0.
3 0. 0.
4 0. 0.
5 0. 0.
6 0. 0.
7 0. 0.
8 0. 0.
9 0. 0.
10 0. 0.
# ... with more rows
Now I apply the function ml_binary_classification_eval:
ml_binary_classification_eval(ml_score,'Survived','prediction')
Then I get this error:
Error: java.lang.IllegalArgumentException: requirement failed: Column prediction must be of type org.apache.spark.mllib.linalg.VectorUDT@f71b0bce but was actually DoubleType.
at scala.Predef$.require(Predef.scala:233)
at org.apache.spark.ml.util.SchemaUtils$.checkColumnType(SchemaUtils.scala:42)
at org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate(BinaryClassificationEvaluator.scala:82)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at sparklyr.Invoke$.invoke(invoke.scala:102)
at sparklyr.StreamHandler$.handleMethodCall(stream.scala:97)
at sparklyr.StreamHandler$.read(stream.scala:62)
at sparklyr.BackendHandler.channelRead0(handler.scala:52)
at sparklyr.BackendHandler.channelRead0(handler.scala:14)
at io.netty.channel.SimpleChannelInboundHandler.channelRead(SimpleChannelInboundHandler.java:105)
at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:308)
at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:294)
at io.netty.handler.codec.MessageToMessageDecoder.channelRead(MessageToMessageDecoder.java:103)
at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:308)
at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:294)
at io.netty.handler.codec.ByteToMessageDecoder.channelRead(ByteToMessageDecoder.java:244)
at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:308)
at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:294)
at io.netty.channel.DefaultChannelPipeline.fireChannelRead(DefaultChannelPipeline.java:846)
at io.netty.channel.nio.AbstractNioByteChannel$NioByteUnsafe.read(AbstractNioByteChannel.java:131)
at io.netty.channel.nio.NioEventLoop.processSelectedKey(NioEventLoop.java:511)
at io.netty.channel.nio.NioEventLoop.processSelectedKeysOptimized(NioEventLoop.java:468)
at io.netty.channel.nio.NioEventLoop.processSelectedKeys(NioEventLoop.java:382)
at io.netty.channel.nio.NioEventLoop.run(NioEventLoop.java:354)
at io.netty.util.concurrent.SingleThreadEventExecutor$2.run(SingleThreadEventExecutor.java:111)
at io.netty.util.concurrent.DefaultThreadFactory$DefaultRunnableDecorator.run(DefaultThreadFactory.java:137)
at java.lang.Thread.run(Thread.java:748)
Answer 0 (score: 1)
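In the current version, you need to pass the raw prediction column name to ml_binary_classification_evaluator(); by default it is "rawPrediction". The error occurs because the evaluator expects a vector-typed column of raw scores, while "prediction" holds the predicted class label as a double. The documentation at ?ml_evaluator was incorrect and has since been updated.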
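As a rough sketch of what that looks like with the objects from the question (ml_log and test_tbl are assumed to exist already, and the scored table is assumed to contain a rawPrediction column, which you can confirm with sdf_schema()). The sdf_predict() call is copied from the question; its argument order has changed between sparklyr versions, so adjust it to whatever already works for you:

```r
library(sparklyr)
library(dplyr)

# Assumption: `ml_log` (fitted model) and `test_tbl` (Spark table) come from
# the question's code and are already defined in this session.

# Score the test set, using the same call that worked in the question.
pred <- sdf_predict(ml_log, test_tbl)

# Optional check: list column names and Spark types. The raw score column
# should show up as a vector type, `prediction` as a double.
sdf_schema(pred)

# Keep rawPrediction instead of dropping it in the select().
ml_score <- select(pred, Survived, rawPrediction, prediction)

# Compute the AUC from the raw scores rather than the class labels.
ml_binary_classification_evaluator(
  ml_score,
  label_col = "Survived",
  raw_prediction_col = "rawPrediction",
  metric_name = "areaUnderROC"
)
```

The only real change from the question is that the helper no longer selects just Survived and prediction, so the vector column the evaluator needs is still there when it runs.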