org.apache.spark.ml.classification.LogisticRegression fit()的输入格式是什么?

时间:2016-08-01 12:12:40

标签: scala apache-spark

在培训LogisticRegression模型的this示例中,他们将RDD [LabeledPoint]用作fit()方法的输入,但是他们写了“//我们使用LabeledPoint,这是一个案例类.Spark SQL可以转换案例类的RDD //进入SchemaRDDs,它使用案例类元数据来推断架构。“

这种转换发生在哪里?当我尝试这段代码时:

val sqlContext = new SQLContext(sc)
import sqlContext._
val model = lr.fit(training);

,其中training是RDD [LabeledPoint]类型的训练,它给出了一个编译错误,指出fit需要一个数据帧。当我将RDD转换为数据帧时,我得到了这个例外:

An exception occured while executing the Java class. null: InvocationTargetException: requirement failed: Column features must be of type org.apache.spark.mllib.linalg.VectorUDT@f71b0bce but was actually StructType(StructField(label,DoubleType,false), StructField(features,org.apache.spark.mllib.linalg.VectorUDT@f71b0bce,true))

但这对我来说很困惑。为什么会期待Vector?它还需要标签。所以我想知道什么是正确的格式?

我使用ML LogisticRegression而不是Mllib LogisticRegressionWithLBFGS的原因是因为我想要一个elasticNet实现。

1 个答案:

答案 0 :(得分:3)

Exception表示DataFrame需要以下结构:

StructType(StructField(label,DoubleType,false), 
StructField(features,org.apache.spark.mllib.linalg.VectorUDT@f71b0bce,true))

所以从这样的(标签,特征)元组列表中准备训练数据:

val training = sqlContext.createDataFrame(Seq(
  (1.0, Vectors.dense(0.0, 1.1, 0.1)),
  (0.0, Vectors.dense(2.0, 1.0, -1.0)),
  (0.0, Vectors.dense(2.0, 1.3, 1.0)),
  (1.0, Vectors.dense(0.0, 1.2, -0.5))
)).toDF("label", "features")