Classification using Logistic Regression

Date: 2017-04-27 17:06:24

Tags: apache-spark machine-learning

I have the following code, where I try to set the label with a StringIndexer and assemble the features with a VectorAssembler:
StructType schema = createStructType(new StructField[]{
    createStructField("id", IntegerType, false),
    createStructField("country", StringType, false),
    createStructField("hour", IntegerType, false),
    createStructField("clicked", DoubleType, false)
});

List<Row> data = Arrays.asList(
    RowFactory.create(7, "US", 18, 1.0),
    RowFactory.create(8, "CA", 12, 0.0),
    RowFactory.create(9, "NZ", 15, 0.0)
);

Dataset<Row> dataset = sparkSession.createDataFrame(data, schema);

StringIndexer indexer = new StringIndexer()
    .setInputCol("clicked")
    .setOutputCol("label");
Dataset<Row> ds = indexer.fit(dataset).transform(dataset);

VectorAssembler assembler = new VectorAssembler()
    .setInputCols(new String[]{"id", "country", "hour"})
    .setOutputCol("features");
Dataset<Row> finalDS = assembler.transform(ds);

LogisticRegression lr = new LogisticRegression()
    .setMaxIter(10)
    .setRegParam(0.3)
    .setElasticNetParam(0.8);

// Fit the model
LogisticRegressionModel lrModel = lr.fit(finalDS);
Dataset<Row> output = lrModel.transform(finalDS);
output.select("features", "label").show();

When I submit it to Spark, I get the following error message:

 7/04/27 22:34:24 INFO DAGScheduler: Job 0 finished: countByValue at StringIndexer.scala:92, took 1.003742 s
Exception in thread "main" java.lang.IllegalArgumentException: Data type StringType is not supported.
    at org.apache.spark.ml.feature.VectorAssembler$$anonfun$transformSchema$1.apply(VectorAssembler.scala:121)
    at org.apache.spark.ml.feature.VectorAssembler$$anonfun$transformSchema$1.apply(VectorAssembler.scala:117)
    at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
    at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186)
    at org.apache.spark.ml.feature.VectorAssembler.transformSchema(VectorAssembler.scala:117)
    at org.apache.spark.ml.PipelineStage.transformSchema(Pipeline.scala:74)
    at org.apache.spark.ml.feature.VectorAssembler.transform(VectorAssembler.scala:54)

1 Answer:

Answer 0 (score: 0)


VectorAssembler accepts only three types of columns:

- DoubleType - double scalars, optionally with column metadata.
- NumericType - arbitrary numeric values.
- VectorUDT - vector columns.

The error occurs because the "country" column is a StringType, which the assembler does not support.
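One common way to fix this (a sketch, not part of the original answer) is to run the StringType "country" column through its own StringIndexer so the assembler only receives numeric columns. The column names `countryIndex` and the variable `indexed` below are illustrative choices, not from the question:

```java
// Sketch of a possible fix: index the StringType "country" column into
// a numeric column before passing it to the VectorAssembler.
StringIndexer countryIndexer = new StringIndexer()
    .setInputCol("country")
    .setOutputCol("countryIndex");
Dataset<Row> indexed = countryIndexer.fit(ds).transform(ds);

// Assemble only numeric columns; "country" is replaced by "countryIndex".
VectorAssembler assembler = new VectorAssembler()
    .setInputCols(new String[]{"id", "countryIndex", "hour"})
    .setOutputCol("features");
Dataset<Row> finalDS = assembler.transform(indexed);
```

Note that an indexed categorical column is ordinal; for a linear model such as logistic regression it is usually followed by a OneHotEncoder before assembling.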

Learn more ->

  1. Formatting data for spark ML
  2. How to create correct data frame for classification in Spark ML