spark管道KMeansModel clusterCenters

时间:2016-05-24 03:34:03

标签: apache-spark apache-spark-ml

我使用管道(Pipeline)对文本文档进行聚类。管道的最后一个阶段是ml.clustering.KMeans,它会返回一个带有聚类预测列的DataFrame。我还想把聚类中心也作为一列添加进去。我知道可以先执行Vector[] clusterCenters = kmeansModel.clusterCenters();,再把结果转换成DataFrame并与另一个DataFrame做连接,但我希望能以类似下面KMeans代码的方式(直接在管道中)实现这一点:

    // Configure the k-means stage used as the last stage of the pipeline:
    // it reads feature vectors from the "pca" column and writes the assigned
    // cluster index to the "kmeansclusterprediction" column.
    KMeans kMeans = new KMeans()
                .setFeaturesCol("pca")
                .setPredictionCol("kmeansclusterprediction")
                .setK(5)                  // number of clusters
                .setInitMode("random")    // initialization mode
                .setSeed(43L)             // fixed seed
                .setInitSteps(3)
                .setMaxIter(15);          // upper bound on iterations

pipeline.setStages( ...

我能够扩展KMeans并通过管道调用其fit方法,但扩展KMeansModel一直没有成功……它的构造函数需要String uid和一个KMeansModel,而我不知道在定义阶段并调用setStages方法时如何把模型传进去。

我也考虑过扩展KMeans.scala,但作为一名Java开发人员,我只能看懂其中一半的代码,所以希望在动手之前,有人能给出更简单的解决方案。最后,我想得到如下所示的DataFrame:

+--------------------+-----------------------+--------------------+
|               docid|kmeansclusterprediction|kmeansclustercenters|
+--------------------+-----------------------+--------------------+
|2bcbcd54-c11a-48c...|                      2|      [-0.04, -7.72]|
|0e644620-f5ff-40f...|                      3|        [0.23, 1.08]|
|665c1c2b-3065-4e8...|                      3|        [0.23, 1.08]|
|598c6268-e4b9-4c9...|                      0|      [-15.81, 0.01]|
+--------------------+-----------------------+--------------------+ 

非常感谢任何帮助或提示。 谢谢

2 个答案:

答案 0 :(得分:0)

回答我自己的问题......这实际上很简单......我扩展了KMeans和KMeansModel ......扩展的Kmeans fit方法必须返回扩展的KMeansModel。例如:

public class AnalyticsKMeansModel extends KMeansModel ...


public class AnalyticsKMeans extends org.apache.spark.ml.clustering.KMeans { ...

/**
 * Trains k-means on the configured features column and wraps the result in an
 * {@link AnalyticsKMeansModel} so that the pipeline uses the extended model.
 *
 * @param dataset input data; must contain the column named by {@code getFeaturesCol()},
 *                holding {@code Vector} values
 * @return the fitted model, carrying this estimator's uid and parameter values
 */
@Override
public AnalyticsKMeansModel fit(DataFrame dataset) {

    // The projection below leaves exactly one column, so the mapper reads it by position.
    // Looking it up by the literal name "pca" would fail for any other features column.
    JavaRDD<Vector> features = dataset.select(this.getFeaturesCol()).toJavaRDD().map(new Function<Row, Vector>() {
        private static final long serialVersionUID = -4588981547209486909L;

        @Override
        public Vector call(Row row) throws Exception {
            return (Vector) row.get(0);
        }

    });

    RDD<Vector> rdd = JavaRDD.toRDD(features);

    // Forward this estimator's parameters to the mllib implementation via the typed getters.
    org.apache.spark.mllib.clustering.KMeans algo = new org.apache.spark.mllib.clustering.KMeans()
            .setK(this.getK())
            .setInitializationMode(this.getInitMode())
            .setInitializationSteps(this.getInitSteps())
            .setMaxIterations(this.getMaxIter())
            .setSeed(this.getSeed())
            .setEpsilon(this.getTol());

    org.apache.spark.mllib.clustering.KMeansModel parentModel = algo.run(rdd);
    AnalyticsKMeansModel model = new AnalyticsKMeansModel(this.uid(), parentModel);

    // copyValues$default$2() is how Java reaches the Scala default for the "extra" argument.
    return (AnalyticsKMeansModel) this.copyValues((Params) model, this.copyValues$default$2());
}

一旦我更改了fit方法以返回扩展的KMeansModel类,一切都按预期工作。

答案 1 :(得分:0)

        import java.util.ArrayList;
        import java.util.Arrays;
        import java.util.List;

        import org.apache.spark.api.java.JavaRDD;
        import org.apache.spark.api.java.JavaSparkContext;
        import org.apache.spark.api.java.function.Function;
        import org.apache.spark.ml.clustering.KMeansModel;
        import org.apache.spark.mllib.linalg.Vector;
        import org.apache.spark.sql.DataFrame;
        import org.apache.spark.sql.Row;
        import org.apache.spark.sql.RowFactory;
        import org.apache.spark.sql.types.DataTypes;
        import org.apache.spark.sql.types.StructField;
        import org.apache.spark.sql.types.StructType;

        import AnalyticsCluster;

        /**
         * K-means model whose {@link #transform(DataFrame)} output carries, next to the cluster
         * prediction, the matching cluster center rendered as a string in the
         * {@code clusterpoint} column.
         */
        public class AnalyticsKMeansModel extends KMeansModel {
            private static final long serialVersionUID = -8893355418042946358L;

            /** Name of the appended column holding the cluster center as text. */
            private static final String CLUSTER_POINT_COL = "clusterpoint";

            /**
             * @param uid         identifier passed through to the parent model
             * @param parentModel fitted mllib k-means model that supplies centers and predictions
             */
            public AnalyticsKMeansModel(String uid, org.apache.spark.mllib.clustering.KMeansModel parentModel) {
                super(uid, parentModel);
            }

            /**
             * Runs the regular k-means transform and inner-joins the result with a small
             * lookup table of (cluster id, cluster center) pairs.
             *
             * @param dataset rows to assign to clusters
             * @return the parent's transformed rows plus the {@code clusterpoint} column
             */
            @Override
            public DataFrame transform(DataFrame dataset) {

                // Build the lookup rows on the driver: there are only k of them, so no
                // distributed map (and no closure capturing this model) is needed.
                Vector[] clusterCenters = super.clusterCenters();
                List<Row> centerRows = new ArrayList<Row>(clusterCenters.length);
                for (Vector center : clusterCenters) {
                    int clusterId = super.predict(center);
                    centerRows.add(RowFactory.create(
                            String.valueOf(clusterId),
                            Arrays.toString(center.toArray())));
                }

                List<StructField> schemaColumns = new ArrayList<StructField>();
                schemaColumns.add(DataTypes.createStructField(this.getPredictionCol(), DataTypes.StringType, false));
                schemaColumns.add(DataTypes.createStructField(CLUSTER_POINT_COL, DataTypes.StringType, false));
                StructType dataFrameSchema = DataTypes.createStructType(schemaColumns);

                JavaSparkContext jsc = JavaSparkContext.fromSparkContext(dataset.sqlContext().sparkContext());
                DataFrame clusterPointsDF =
                        dataset.sqlContext().createDataFrame(jsc.parallelize(centerRows), dataFrameSchema);

                // Sometimes k is set to a value greater than the number of actual rows of data,
                // which yields repeated centers; keep only the distinct (id, center) pairs.
                // distinct() replaces the former "select distinct" over a temp table with a fixed
                // name, which could collide with other tables or concurrent transform calls.
                DataFrame clustersDF = clusterPointsDF.distinct();
                clustersDF.cache();

                DataFrame transformedDF = super.transform(dataset);
                // NOTE(review): this cache is never unpersisted here; callers own its lifetime.
                transformedDF.cache();

                // Join on the prediction column, then drop the lookup table's copy of it.
                return transformedDF
                        .join(clustersDF,
                                transformedDF.col(this.getPredictionCol())
                                        .equalTo(clustersDF.col(this.getPredictionCol())),
                                "inner")
                        .drop(clustersDF.col(this.getPredictionCol()));
            }
        }





    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.function.Function;
    import org.apache.spark.ml.param.Param;
    import org.apache.spark.ml.param.Params;
    import org.apache.spark.mllib.linalg.Vector;
    import org.apache.spark.rdd.RDD;
    import org.apache.spark.sql.DataFrame;
    import org.apache.spark.sql.Row;

    import scala.runtime.BoxesRunTime;

    /**
     * K-means estimator whose {@link #fit(DataFrame)} returns an {@link AnalyticsKMeansModel}
     * instead of the stock model, so the extended model is what ends up in a fitted pipeline.
     */
    public class AnalyticsKMeans extends org.apache.spark.ml.clustering.KMeans {
        private static final long serialVersionUID = 8943702485821267996L;

        /**
         * @param uid identifier for this estimator; handed to the parent so that
         *            {@code uid()} returns it (it is also passed on to the fitted model)
         */
        public AnalyticsKMeans(String uid) {
            // Previously the value was parked in an unused static field shared by every
            // instance, and the parent silently generated a random uid instead.
            super(uid);
        }

        /**
         * Trains k-means on the configured features column.
         *
         * @param dataset input data; must contain the column named by {@code getFeaturesCol()},
         *                holding {@code Vector} values
         * @return the fitted model, carrying this estimator's uid and parameter values
         */
        @Override
        public AnalyticsKMeansModel fit(DataFrame dataset) {

            // The projection below leaves exactly one column, so the mapper reads it by position.
            // Looking it up by the literal name "pca" would fail for any other features column.
            JavaRDD<Vector> features = dataset.select(this.getFeaturesCol()).toJavaRDD().map(new Function<Row, Vector>() {
                private static final long serialVersionUID = -4588981547209486909L;

                @Override
                public Vector call(Row row) throws Exception {
                    return (Vector) row.get(0);
                }

            });

            RDD<Vector> rdd = JavaRDD.toRDD(features);

            // Forward this estimator's parameters to the mllib implementation via the typed getters.
            org.apache.spark.mllib.clustering.KMeans algo = new org.apache.spark.mllib.clustering.KMeans()
                    .setK(this.getK())
                    .setInitializationMode(this.getInitMode())
                    .setInitializationSteps(this.getInitSteps())
                    .setMaxIterations(this.getMaxIter())
                    .setSeed(this.getSeed())
                    .setEpsilon(this.getTol());

            org.apache.spark.mllib.clustering.KMeansModel parentModel = algo.run(rdd);
            AnalyticsKMeansModel model = new AnalyticsKMeansModel(this.uid(), parentModel);

            // copyValues$default$2() is how Java reaches the Scala default for the "extra" argument.
            return (AnalyticsKMeansModel) this.copyValues((Params) model, this.copyValues$default$2());
        }

    }




import java.io.Serializable;
import java.util.Arrays;

/**
 * Running summary of a single cluster: a fixed id, the current center and the number of
 * points folded into that center so far.
 *
 * <p>{@link #update(double[], long)} and {@link #toString()} are synchronized; the getters
 * read the volatile fields without taking the lock.
 */
public class AnalyticsCluster implements Serializable {
    private static final long serialVersionUID = 6535671221958712594L;

    private final int id;
    private volatile double[] center;
    private volatile long count;

    /**
     * @param id           cluster identifier
     * @param center       initial center coordinates; the array is stored as given, not copied
     * @param initialCount number of points already represented by {@code center} (may be 0)
     */
    public AnalyticsCluster(int id, double[] center, long initialCount) {
        this.id = id;
        this.center = center;
        this.count = initialCount;
    }

    /** @return the cluster identifier */
    public int getID() {
        return id;
    }

    /** @return the current center array (the internal array, not a copy) */
    public double[] getCenter() {
        return center;
    }

    /** @return the number of points accumulated so far */
    public long getCount() {
        return count;
    }

    /**
     * Moves the center toward {@code newPoint}, weighting it by {@code newCount} relative to
     * the combined count, and adds {@code newCount} to the running count.
     *
     * <p>If the combined count is zero there is nothing to weight by, so the cluster is left
     * unchanged rather than turning every coordinate into NaN (0/0).
     *
     * @param newPoint coordinates to fold in; must have the same length as the center
     * @param newCount number of points that {@code newPoint} stands for
     * @throws IllegalArgumentException if {@code newPoint} has a different length than the center
     */
    public synchronized void update(double[] newPoint, long newCount) {
        int length = center.length;
        if (newPoint.length != length) {
            throw new IllegalArgumentException(
                    "Point has " + newPoint.length + " dimensions but the center has " + length);
        }

        long newTotalCount = newCount + count;
        if (newTotalCount == 0L) {
            return;
        }

        double newToTotal = (double) newCount / newTotalCount;
        // Write into a fresh array and swap it in, so readers never see a half-updated center.
        double[] newCenter = new double[length];
        for (int i = 0; i < length; i++) {
            double centerI = center[i];
            newCenter[i] = centerI + newToTotal * (newPoint[i] - centerI);
        }
        center = newCenter;
        count = newTotalCount;
    }

    /** @return {@code "<id> <center as Arrays.toString> <count>"} */
    @Override
    public synchronized String toString() {
        return id + " " + Arrays.toString(center) + " " + count;
    }

}