如何在scala中设置逻辑回归的数据?

时间:2017-07-07 01:04:14

标签: scala apache-spark logistic-regression

我是scala的新手,我想实现一个逻辑回归模型。所以我最初加载一个csv文件如下:

val sqlContext = new org.apache.spark.sql.SQLContext(sc)
 val df = sqlContext.read.format("com.databricks.spark.csv")
    .option("header", "true")
    .option("inferSchema", "true")
    .load("D:/sample.txt")

文件如下:

P,P,A,A,A,P,NB
N,N,A,A,A,N,NB
A,A,A,A,A,A,NB
P,P,P,P,P,P,NB
N,N,P,P,P,N,NB
A,A,P,P,P,A,NB
P,P,A,P,P,P,NB
P,P,P,A,A,P,NB
P,P,A,P,A,P,NB
P,P,A,A,P,P,NB
P,P,P,P,A,P,NB
P,P,P,A,P,P,NB
N,N,A,P,P,N,NB
N,N,P,A,A,N,NB
N,N,A,P,A,N,NB
N,N,A,P,A,N,NB
N,N,A,A,P,N,NB
N,N,P,P,A,N,NB
N,N,P,A,P,N,NB
A,A,A,P,P,A,NB
A,A,P,A,A,A,NB
A,A,A,P,A,A,NB
A,A,A,A,P,A,NB
A,A,P,P,A,A,NB
A,A,P,A,P,A,NB
P,N,A,A,A,P,NB
N,P,A,A,A,N,NB
P,N,A,A,A,N,NB
P,N,P,P,P,P,NB
N,P,P,P,P,N,NB

然后我想通过以下代码训练模型:

val lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.3)
      .setElasticNetParam(0.8)
      .setFeaturesCol("Feature")
      .setLabelCol("Label")

然后我按照下面的方式拟合模型:

 val lrModel = lr.fit(df)

println(lrModel.coefficients +"are the coefficients")
println(lrModel.interceptVector+"are the intercerpt vactor")
println(lrModel.summary +"is summary")

但它不会打印结果。

感谢任何帮助。

1 个答案:

答案 0 :(得分:1)

来自您的代码:

val lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.3)
      .setElasticNetParam(0.8)
      .setFeaturesCol("Feature")  <- here
      .setLabelCol("Label") <- here

您正在设置features列和label列。由于您没有提及列名称,我假设包含NB值的列是您的标签,并且您希望包含所有其他列是用于预测的列。

您希望包含在模型中的所有预测变量都需要采用单向量列的形式,通常称为features列。您需要使用VectorAssembler创建它,如下所示:

import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.Vectors

//creating features column
val assembler = new VectorAssembler()
  .setInputCols(Array(" insert your column names here "))
  .setOutputCol("Feature")

参考:https://spark.apache.org/docs/latest/ml-features.html#vectorassembler

现在您可以继续使用逻辑回归模型。 pipeline用于在fitting数据之前合并多个转换。

val pipeline = new Pipeline().setStages(Array(assembler,lr))

//fitting the model
val lrModel = pipeline.fit(df)