I am using a pipeline to cluster text documents. The last stage of the pipeline is ml.clustering.KMeans, which gives me a DataFrame with a column of cluster predictions. I would like to add the cluster centers as a column as well. I understand that I could run Vector[] clusterCenters = kmeansModel.clusterCenters();
and then convert the result to a DataFrame and join it back onto the other DataFrame, but I was hoping to find a way to do this that looks like the KMeans code below:
KMeans kMeans = new KMeans()
    .setFeaturesCol("pca")
    .setPredictionCol("kmeansclusterprediction")
    .setK(5)
    .setInitMode("random")
    .setSeed(43L)
    .setInitSteps(3)
    .setMaxIter(15);
pipeline.setStages( ...
I was able to extend KMeans and call its fit method through the pipeline, but I have not had any luck extending KMeansModel ... its constructor requires a String uid and a KMeansModel (the underlying mllib model), and I don't see how to pass that model in when I am defining the stages and calling setStages.
I also considered extending KMeans.scala, but as a Java developer I only understand about half of that code, so I am hoping someone has a simpler solution before I tackle it. In the end, I would like to have a DataFrame like this:
+--------------------+-----------------------+--------------------+
|               docid|kmeansclusterprediction|kmeansclustercenters|
+--------------------+-----------------------+--------------------+
|2bcbcd54-c11a-48c...|                      2|      [-0.04, -7.72]|
|0e644620-f5ff-40f...|                      3|        [0.23, 1.08]|
|665c1c2b-3065-4e8...|                      3|        [0.23, 1.08]|
|598c6268-e4b9-4c9...|                      0|      [-15.81, 0.01]|
+--------------------+-----------------------+--------------------+
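For reference, the clusterCenters() / join workaround I mentioned above would look roughly like this (just a sketch; predictionsDF is a placeholder for the output of my pipeline, and the column names match the example above):

Vector[] clusterCenters = kmeansModel.clusterCenters();

// Build a small (prediction id, center) DataFrame from the center array.
List<Row> centerRows = new ArrayList<Row>();
for (int i = 0; i < clusterCenters.length; i++) {
    centerRows.add(RowFactory.create(i, Arrays.toString(clusterCenters[i].toArray())));
}
StructType centerSchema = DataTypes.createStructType(Arrays.asList(
    DataTypes.createStructField("kmeansclusterprediction", DataTypes.IntegerType, false),
    DataTypes.createStructField("kmeansclustercenters", DataTypes.StringType, false)));

// predictionsDF is a placeholder for the pipeline's transformed output.
JavaSparkContext jsc = JavaSparkContext.fromSparkContext(predictionsDF.sqlContext().sparkContext());
DataFrame centersDF = predictionsDF.sqlContext().createDataFrame(jsc.parallelize(centerRows), centerSchema);

// Join the centers back onto the predictions by the prediction id.
DataFrame withCenters = predictionsDF.join(centersDF, "kmeansclusterprediction");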
Any help or hints would be much appreciated. Thanks.
Answer 0 (score: 0)
Answering my own question ... this turned out to be fairly simple. I extended both KMeans and KMeansModel; the extended KMeans's fit method just has to return the extended KMeansModel. For example:
public class AnalyticsKMeansModel extends KMeansModel ...
public class AnalyticsKMeans extends org.apache.spark.ml.clustering.KMeans { ...
public AnalyticsKMeansModel fit(DataFrame dataset) {
    // Pull the feature vectors out of the input DataFrame.
    JavaRDD<Vector> javaRDD = dataset.select(this.getFeaturesCol()).toJavaRDD().map(new Function<Row, Vector>() {
        private static final long serialVersionUID = -4588981547209486909L;
        @Override
        public Vector call(Row row) throws Exception {
            return (Vector) row.getAs("pca");
        }
    });
    RDD<Vector> rdd = JavaRDD.toRDD(javaRDD);

    // Run the underlying mllib KMeans with the parameters that were set on this estimator.
    org.apache.spark.mllib.clustering.KMeans algo = new org.apache.spark.mllib.clustering.KMeans()
        .setK(BoxesRunTime.unboxToInt(this.$((Param<?>) this.k())))
        .setInitializationMode((String) this.$(this.initMode()))
        .setInitializationSteps(BoxesRunTime.unboxToInt((Object) this.$((Param<?>) this.initSteps())))
        .setMaxIterations(BoxesRunTime.unboxToInt((Object) this.$((Param<?>) this.maxIter())))
        .setSeed(BoxesRunTime.unboxToLong((Object) this.$((Param<?>) this.seed())))
        .setEpsilon(BoxesRunTime.unboxToDouble((Object) this.$((Param<?>) this.tol())));
    org.apache.spark.mllib.clustering.KMeansModel parentModel = algo.run(rdd);

    // Wrap the mllib model in the extended ml model so the pipeline gets the custom transform.
    AnalyticsKMeansModel model = new AnalyticsKMeansModel(this.uid(), parentModel);
    return (AnalyticsKMeansModel) this.copyValues((Params) model, this.copyValues$default$2());
}
Once I changed the fit method to return the extended KMeansModel class, everything worked as expected.
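With that change, the extended estimator drops into the pipeline exactly where the stock KMeans was. Roughly like this (a sketch; documentsDF, the "analyticskmeans" uid string, and the earlier stages are placeholders for whatever your pipeline already has, and the usual Pipeline/PipelineStage/PipelineModel imports are assumed):

// Sketch: the extended estimator is configured like the stock KMeans.
AnalyticsKMeans kMeans = new AnalyticsKMeans("analyticskmeans");
kMeans.setFeaturesCol("pca");
kMeans.setPredictionCol("kmeansclusterprediction");
kMeans.setK(5);
kMeans.setInitMode("random");
kMeans.setSeed(43L);
kMeans.setInitSteps(3);
kMeans.setMaxIter(15);

Pipeline pipeline = new Pipeline();
// Earlier stages (tokenizer, vectorizer, PCA, ...) go in front of the clustering stage.
pipeline.setStages(new PipelineStage[] { kMeans });

PipelineModel pipelineModel = pipeline.fit(documentsDF);     // documentsDF is a placeholder
DataFrame clustered = pipelineModel.transform(documentsDF);  // includes the joined cluster centers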
Answer 1 (score: 0)
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
// AnalyticsCluster (defined below) lives in the same package, so it does not need an import.
public class AnalyticsKMeansModel extends KMeansModel {
    private static final long serialVersionUID = -8893355418042946358L;

    public AnalyticsKMeansModel(String uid, org.apache.spark.mllib.clustering.KMeansModel parentModel) {
        super(uid, parentModel);
    }

    public DataFrame transform(DataFrame dataset) {
        // Build one AnalyticsCluster per cluster center, keyed by the prediction id of that center.
        Vector[] clusterCenters = super.clusterCenters();
        List<AnalyticsCluster> analyticsClusters = new ArrayList<AnalyticsCluster>();
        for (int i = 0; i < clusterCenters.length; i++) {
            Integer clusterId = super.predict(clusterCenters[i]);
            Vector vector = clusterCenters[i];
            double[] point = vector.toArray();
            AnalyticsCluster analyticsCluster = new AnalyticsCluster(clusterId, point, 0L);
            analyticsClusters.add(analyticsCluster);
        }

        // Turn the cluster list into a small DataFrame of (prediction, clusterpoint) rows.
        JavaSparkContext jsc = JavaSparkContext.fromSparkContext(dataset.sqlContext().sparkContext());
        JavaRDD<AnalyticsCluster> javaRDD = jsc.parallelize(analyticsClusters);
        JavaRDD<Row> javaRDDRow = javaRDD.map(new Function<AnalyticsCluster, Row>() {
            private static final long serialVersionUID = -2677295862916670965L;
            @Override
            public Row call(AnalyticsCluster cluster) throws Exception {
                Row row = RowFactory.create(
                        String.valueOf(cluster.getID()),
                        String.valueOf(Arrays.toString(cluster.getCenter()))
                );
                return row;
            }
        });

        List<StructField> schemaColumns = new ArrayList<StructField>();
        schemaColumns.add(DataTypes.createStructField(this.getPredictionCol(), DataTypes.StringType, false));
        schemaColumns.add(DataTypes.createStructField("clusterpoint", DataTypes.StringType, false));
        StructType dataFrameSchema = DataTypes.createStructType(schemaColumns);
        DataFrame clusterPointsDF = dataset.sqlContext().createDataFrame(javaRDDRow, dataFrameSchema);

        // Sometimes "k" is set to a value greater than the number of actual rows of data ... keep distinct values only.
        clusterPointsDF.registerTempTable("clusterPoints");
        DataFrame clustersDF = clusterPointsDF.sqlContext().sql(
                "select distinct " + this.getPredictionCol() + ", clusterpoint from clusterPoints");
        clustersDF.cache();
        clusterPointsDF.sqlContext().dropTempTable("clusterPoints");

        // Run the normal KMeansModel transform, then join the cluster centers back on the prediction column.
        DataFrame transformedDF = super.transform(dataset);
        transformedDF.cache();
        DataFrame df = transformedDF.join(clustersDF,
                transformedDF.col(this.getPredictionCol()).equalTo(clustersDF.col(this.getPredictionCol())), "inner")
                .drop(clustersDF.col(this.getPredictionCol()));
        return df;
    }
}
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.Params;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import scala.runtime.BoxesRunTime;
public class AnalyticsKMeans extends org.apache.spark.ml.clustering.KMeans {
    private static final long serialVersionUID = 8943702485821267996L;

    public AnalyticsKMeans(String uid) {
        // Pass the uid to the parent estimator so this.uid() returns it.
        super(uid);
    }

    public AnalyticsKMeansModel fit(DataFrame dataset) {
        // Pull the feature vectors out of the input DataFrame.
        JavaRDD<Vector> javaRDD = dataset.select(this.getFeaturesCol()).toJavaRDD().map(new Function<Row, Vector>() {
            private static final long serialVersionUID = -4588981547209486909L;
            @Override
            public Vector call(Row row) throws Exception {
                return (Vector) row.getAs("pca");
            }
        });
        RDD<Vector> rdd = JavaRDD.toRDD(javaRDD);

        // Run the underlying mllib KMeans with the parameters that were set on this estimator.
        org.apache.spark.mllib.clustering.KMeans algo = new org.apache.spark.mllib.clustering.KMeans()
            .setK(BoxesRunTime.unboxToInt(this.$((Param<?>) this.k())))
            .setInitializationMode((String) this.$(this.initMode()))
            .setInitializationSteps(BoxesRunTime.unboxToInt((Object) this.$((Param<?>) this.initSteps())))
            .setMaxIterations(BoxesRunTime.unboxToInt((Object) this.$((Param<?>) this.maxIter())))
            .setSeed(BoxesRunTime.unboxToLong((Object) this.$((Param<?>) this.seed())))
            .setEpsilon(BoxesRunTime.unboxToDouble((Object) this.$((Param<?>) this.tol())));
        org.apache.spark.mllib.clustering.KMeansModel parentModel = algo.run(rdd);

        // Wrap the mllib model in the extended ml model so transform() adds the cluster centers.
        AnalyticsKMeansModel model = new AnalyticsKMeansModel(this.uid(), parentModel);
        return (AnalyticsKMeansModel) this.copyValues((Params) model, this.copyValues$default$2());
    }
}
import java.io.Serializable;
import java.util.Arrays;

public class AnalyticsCluster implements Serializable {
    private static final long serialVersionUID = 6535671221958712594L;

    private final int id;
    private volatile double[] center;
    private volatile long count;

    public AnalyticsCluster(int id, double[] center, long initialCount) {
        // Preconditions.checkArgument(center.length > 0);
        // Preconditions.checkArgument(initialCount >= 1);
        this.id = id;
        this.center = center;
        this.count = initialCount;
    }

    public int getID() {
        return id;
    }

    public double[] getCenter() {
        return center;
    }

    public long getCount() {
        return count;
    }

    // Incrementally fold a new batch of points into the running center (weighted mean update).
    public synchronized void update(double[] newPoint, long newCount) {
        int length = center.length;
        // Preconditions.checkArgument(length == newPoint.length);
        double[] newCenter = new double[length];
        long newTotalCount = newCount + count;
        double newToTotal = (double) newCount / newTotalCount;
        for (int i = 0; i < length; i++) {
            double centerI = center[i];
            newCenter[i] = centerI + newToTotal * (newPoint[i] - centerI);
        }
        center = newCenter;
        count = newTotalCount;
    }

    @Override
    public synchronized String toString() {
        return id + " " + Arrays.toString(center) + " " + count;
    }

    // public static void main(String[] args) {
    //     double[] point = new double[2];
    //     point[0] = 0.10150532938119154;
    //     point[1] = -0.23734759238651829;
    //
    //     Cluster cluster = new Cluster(1, point, 10L);
    //     System.out.println("cluster: " + cluster.toString());
    // }
}
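Used outside a pipeline, the three classes fit together roughly like this (a sketch; docsWithPcaDF is a placeholder for a DataFrame that already has the "pca" features column, and the "docid" column comes from the example in the question):

// Sketch: fit the extended estimator directly and inspect the joined output.
AnalyticsKMeans kMeans = new AnalyticsKMeans("analyticskmeans");
kMeans.setFeaturesCol("pca");
kMeans.setPredictionCol("kmeansclusterprediction");
kMeans.setK(5);

AnalyticsKMeansModel model = kMeans.fit(docsWithPcaDF);   // docsWithPcaDF is a placeholder
DataFrame withCenters = model.transform(docsWithPcaDF);   // adds the "clusterpoint" column via the join
withCenters.select("docid", "kmeansclusterprediction", "clusterpoint").show();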