Spark 2.1.0 - SparkML requirement failed

Asked: 2017-07-29 17:51:51

Tags: java apache-spark k-means apache-spark-ml

I am experimenting with the KMeans clustering algorithm in Spark 2.1.0.

import java.util.Arrays;
import java.util.List;

import org.apache.spark.ml.clustering.KMeans;
import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class ClusteringTest {
    public static void main(String[] args) {
        SparkSession session = SparkSession.builder()
                .appName("Clustering Test")
                .config("spark.master", "local")
                .getOrCreate();
        session.sparkContext().setLogLevel("ERROR");

        List<Row> rawDataTraining = Arrays.asList(
                RowFactory.create(1.0, Vectors.dense(1.0, 1.0, 1.0).toSparse()),
                RowFactory.create(1.0, Vectors.dense(2.0, 2.0, 2.0).toSparse()),
                RowFactory.create(1.0, Vectors.dense(3.0, 3.0, 3.0).toSparse()),

                RowFactory.create(2.0, Vectors.dense(6.0, 6.0, 6.0).toSparse()),
                RowFactory.create(2.0, Vectors.dense(7.0, 7.0, 7.0).toSparse()),
                RowFactory.create(2.0, Vectors.dense(8.0, 8.0, 8.0).toSparse())
                // ...
        );

        StructType schema = new StructType(new StructField[]{
                new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
                new StructField("features", new VectorUDT(), false, Metadata.empty())
        });

        Dataset<Row> myRawData = session.createDataFrame(rawDataTraining, schema);
        Dataset<Row>[] splits = myRawData.randomSplit(new double[]{0.75, 0.25});
        Dataset<Row> trainingData = splits[0];
        Dataset<Row> testData = splits[1];

        // Train KMeans
        KMeans kMeans = new KMeans().setK(3).setSeed(100);
        KMeansModel kMeansModel = kMeans.fit(trainingData);
        Dataset<Row> predictions = kMeansModel.transform(testData);
        predictions.show(false);

        MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
                .setLabelCol("label")
                .setPredictionCol("prediction")
                .setMetricName("accuracy");
        double accuracy = evaluator.evaluate(predictions);
        System.out.println("accuracy: " + accuracy);
    }
}

The console output is:

+-----+----------------------------+----------+
|label|features                    |prediction|
+-----+----------------------------+----------+
|2.0  |(3,[0,1,2],[7.0,7.0,7.0])   |2         |
|3.0  |(3,[0,1,2],[11.0,11.0,11.0])|2         |
|3.0  |(3,[0,1,2],[12.0,12.0,12.0])|1         |
|3.0  |(3,[0,1,2],[13.0,13.0,13.0])|1         |
+-----+----------------------------+----------+

Exception in thread "main" java.lang.IllegalArgumentException: requirement failed: Column prediction must be of type DoubleType but was actually IntegerType.
    at scala.Predef$.require(Predef.scala:233)
    at org.apache.spark.ml.util.SchemaUtils$.checkColumnType(SchemaUtils.scala:42)
    at org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator.evaluate(MulticlassClassificationEvaluator.scala:75)
    at ClusteringTest.main(ClusteringTest.java:84)

Process finished with exit code 1

As you can see, the predictions are integers. But to use MulticlassClassificationEvaluator, I need those predictions converted to Double. How can I do that?
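For reference, the type error itself can be silenced by casting the prediction column from IntegerType to DoubleType before evaluation (a sketch continuing the code above; whether the resulting "accuracy" means anything for a clustering model is a separate question, addressed in the answer below):

import static org.apache.spark.sql.functions.col;

// Cast the IntegerType prediction column to DoubleType so the
// evaluator's schema check passes. This only fixes the type
// mismatch; it does not make cluster IDs comparable to labels.
Dataset<Row> castPredictions = predictions
        .withColumn("prediction", col("prediction").cast("double"));
double accuracy = evaluator.evaluate(castPredictions);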

1 Answer:

Answer 0 (score: 1)

TL;DR This is not the right approach.

KMeans is an unsupervised method. The cluster identifiers it produces are arbitrary (cluster IDs can be freely permuted between runs) and have no relationship to the label column. Therefore, using MulticlassClassificationEvaluator to compare the existing labels against KMeans output is meaningless.

You should use a supervised classifier instead, such as multinomial logistic regression or Naive Bayes.
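A minimal sketch of that idea, reusing the trainingData/testData from the question (assumes imports from org.apache.spark.ml.classification; note also that Spark ML classifiers expect labels in [0, numClasses), so the labels 1.0/2.0/3.0 should ideally be re-indexed first, e.g. with StringIndexer):

// Supervised multinomial logistic regression, for which comparing
// predictions against the label column is actually meaningful.
LogisticRegression lr = new LogisticRegression()
        .setFamily("multinomial")
        .setLabelCol("label")
        .setFeaturesCol("features");
LogisticRegressionModel lrModel = lr.fit(trainingData);
Dataset<Row> lrPredictions = lrModel.transform(testData);

double lrAccuracy = new MulticlassClassificationEvaluator()
        .setLabelCol("label")
        .setPredictionCol("prediction")
        .setMetricName("accuracy")
        .evaluate(lrPredictions);
System.out.println("LR accuracy: " + lrAccuracy);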

If you want to stick with KMeans, use an appropriate quality metric, such as the one returned by computeCost, though it will ignore the label information entirely.
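That metric can be obtained directly from the fitted model in the question's code (a sketch; computeCost is available on KMeansModel in Spark 2.x):

// Within-cluster sum of squared distances (WSSSE): an unsupervised
// quality measure for KMeans that ignores the label column entirely.
double wssse = kMeansModel.computeCost(testData);
System.out.println("Within Set Sum of Squared Errors = " + wssse);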