至少常见的DecisionTreeRegressor和RandomForestRegressor上层

时间:2016-11-14 19:06:25

标签: scala apache-spark-mllib type-parameter

我想创建一个返回以下两种类型之一的方法:

 - org.apache.spark.ml.regression.DecisionTreeRegressor
 - org.apache.spark.ml.regression.RandomForestRegressor

此方法的返回类型是什么?我相信它会有一个返回类型

 - org.apache.spark.ml.Estimator<M>

但我不知道类型参数M应该是什么。

如果我只是这样做(例如):

  def getRegressor(): org.apache.spark.ml.Estimator = {
    new DecisionTreeRegressor()
      .setLabelCol("label")
      .setFeaturesCol("features")
      .setMaxBins(100)
  }

我收到以下错误:

  

class Estimator采用类型参数

1 个答案:

答案 0 :(得分:2)

使用_(下划线)忽略该类型如果您不关心它。

def getRegressor(): org.apache.spark.ml.Estimator[_] = {
    new DecisionTreeRegressor()
      .setLabelCol("label")
      .setFeaturesCol("features")
      .setMaxBins(100)
  }

我认为类型应为DecisionTreeRegressionModel

 def getRegressor(): org.apache.spark.ml.Estimator[DecisionTreeRegressionModel] = {
    new DecisionTreeRegressor()
      .setLabelCol("label")
      .setFeaturesCol("features")
      .setMaxBins(100)
  }