我正在尝试扩展或代理org.apache.spark.ml.clustering.KMeans类,以便授权K = 1。
class K1Means extends Estimator{
final val kmeans = new KMeans()
val k = 1
override def setK(value:Int) {
if(value >1){
this.kmeans.setK(value)
}
}
override def fit(dataset: DataFrame): KMeansModel = {
if(this.k == 1){
/* super specific to my case */
val avg_sample = Vectors.zeros(
dataset
.select("scaledFeatures")
.take(1)(0)(0) // first row
.asInstanceOf[DenseVector] // was of type Any
.size
) // with the scaling the average value of each column is 0
var centers_local = Array(avg_sample)
return new KMeansModel(centers_local)
}
else{
return this.kmeans.fit(dataset)
}
}
// every method then calls this.kmeans.method()
}
我试过这个,但new KMeansModel(centers_local)
未获得授权,因为KMeansModel有一个私有构造函数。
以下是错误消息:
constructor KMeansModel in class KMeansModel cannot be accessed in class K1Means
我还试图扩展KMeansModel,所以我可以自己创建并返回它:
class K1MeansModel(centers: Array[DenseVector]) extends KMeansModel{}
但它也失败了:constructor KMeansModel in class KMeansModel cannot be accessed in class K1MeansModel
答案 0 :(得分:4)
这里有几个问题,从KMeansModel开始是私有的: https://github.com/apache/spark/blob/4f83ca1059a3b580fca3f006974ff5ac4d5212a1/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala#L102
为什么这是一个问题?您可以按照建议的方式完全编写自己的代理,但是为了覆盖“fit”方法,该函数返回的数据类型需要是KMeansModel或兼容(比如说“K1MeansModel”),如下所示: / p>
class K1MeansModel extends KMeansModel{
// ...
}
class K1Means extends KMeans{
final val kmeans = new KMeans()
// ...
override def fit(dataset: DataFrame): KMeansModel = {
if(this.k == 1){
// ...
return new K1MeansModel(centers_local)
}
else{
return this.kmeans.fit(dataset)
}
}
}
但是,因为 KMeansModel 是私密的,这是不可能的。所以你可能会想“为什么不重新实现呢?”。的确,你可以复制&从GitHub粘贴 KMeansModel 的完整代码。
KMeansModel的定义如下:
class KMeansModel (
override val uid: String,
private val parentModel: MLlibKMeansModel)
extends Model[KMeansModel] with KMeansParams { }
但是,因为 KMeansParams 是私密的,这是不可能的。所以你可能会想“为什么不重新实现呢?”。的确,你可以复制&从GitHub粘贴 KMeansParams 的全部代码。
KMeansParams的定义如下:
trait K1MeansParams
extends Params
with HasMaxIter
with HasFeaturesCol
with HasSeed
with HasPredictionCol
with HasTol { }
但是,因为 HasMaxIter,HasFeaturesCol,HasSeed,HasPredictionCol,HasTol 都是私有的,这是不可能的。 ......你明白了。
TL; DR 是的,您可以重新实现(复制和粘贴)大量的火花类到您的项目中,只是为了覆盖KMeans。我计算至少7个需要复制和粘贴的课程。对我来说,感觉很糟糕。 相反,我建议将代码直接添加到Apache Spark。分叉Spark GitHub repo,将您的K = 1代码直接添加到ml.KMeans类中,然后完成。