如何使用私有构造函数扩展(或代理)scala类

时间:2016-06-21 09:48:34

标签: scala inheritance private proxy-classes

我正在尝试扩展或代理org.apache.spark.ml.clustering.KMeans类,以便授权K = 1。

class K1Means extends Estimator{

    final val kmeans = new KMeans()
    val k = 1

    override def setK(value:Int) {
        if(value >1){
            this.kmeans.setK(value)
        }
    }

    override def fit(dataset: DataFrame): KMeansModel = { 
        if(this.k == 1){
            /* super specific to my case */
            val avg_sample = Vectors.zeros(
                dataset
                .select("scaledFeatures")
                .take(1)(0)(0)  // first row
                .asInstanceOf[DenseVector]  // was of type Any
                .size
            ) // with the scaling the average value of each column is 0
            var centers_local = Array(avg_sample)
            return new KMeansModel(centers_local)
        }
        else{
            return this.kmeans.fit(dataset)
        }
    }
// every method then calls this.kmeans.method()
}

我试过这个,但new KMeansModel(centers_local)未获得授权,因为KMeansModel有一个私有构造函数。 以下是错误消息:

constructor KMeansModel in class KMeansModel cannot be accessed in class K1Means

我还试图扩展KMeansModel,所以我可以自己创建并返回它:

class K1MeansModel(centers: Array[DenseVector]) extends KMeansModel{}

但它也失败了:constructor KMeansModel in class KMeansModel cannot be accessed in class K1MeansModel

1 个答案:

答案 0 :(得分:4)

这里有几个问题,从KMeansModel开始是私有的: https://github.com/apache/spark/blob/4f83ca1059a3b580fca3f006974ff5ac4d5212a1/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala#L102

为什么这是一个问题?您可以按照建议的方式完全编写自己的代理,但是为了覆盖“fit”方法,该函数返回的数据类型需要是KMeansModel或兼容(比如说“K1MeansModel”),如下所示: / p>

class K1MeansModel extends KMeansModel{
    // ...
}

class K1Means extends KMeans{

    final val kmeans = new KMeans()
    // ...

    override def fit(dataset: DataFrame): KMeansModel = { 
        if(this.k == 1){
            // ...
            return new K1MeansModel(centers_local)
        }
        else{
            return this.kmeans.fit(dataset)
        }
    }
}

但是,因为 KMeansModel 是私密的,这是不可能的。所以你可能会想“为什么不重新实现呢?”。的确,你可以复制&从GitHub粘贴 KMeansModel 的完整代码。

KMeansModel的定义如下:

class KMeansModel (
        override val uid: String, 
        private val parentModel: MLlibKMeansModel) 
    extends Model[KMeansModel] with KMeansParams { }

但是,因为 KMeansParams 是私密的,这是不可能的。所以你可能会想“为什么不重新实现呢?”。的确,你可以复制&从GitHub粘贴 KMeansParams 的全部代码。

KMeansParams的定义如下:

trait K1MeansParams 
    extends Params 
        with HasMaxIter 
        with HasFeaturesCol 
        with HasSeed 
        with HasPredictionCol 
        with HasTol { }

但是,因为 HasMaxIter,HasFeaturesCol,HasSeed,HasPredictionCol,HasTol 都是私有的,这是不可能的。 ......你明白了。

TL; DR 是的,您可以重新实现(复制和粘贴)大量的火花类到您的项目中,只是为了覆盖KMeans。我计算至少7个需要复制和粘贴的课程。对我来说,感觉很糟糕。 相反,我建议将代码直接添加到Apache Spark。分叉Spark GitHub repo,将您的K = 1代码直接添加到ml.KMeans类中,然后完成。