Why do I get a Scala error with ParamGridBuilder and CountVectorizer?

Date: 2018-05-17 05:59:06

Tags: scala apache-spark apache-spark-ml

I have a question about using a param grid with CountVectorizer for k-fold cross-validation. I can't tell what the problem is, because the type in the error message looks the same as the recommended type.

Here is the error:

<console>:57: error: missing argument list for method setMinTF in class CountVectorizer
Unapplied methods are only converted to functions when a function type is expected.
You can make this conversion explicit by writing `setMinTF _` or `setMinTF(_)` instead of `setMinTF`.
        addGrid(countVectorizer.setMinTF, Array(1,3,5,7,9)).
                                ^
<console>:56: error: not found: value paramGrid
                setEstimatorParamMaps(paramGrid).

Here is my code:

val countVectorizer = new CountVectorizer().setInputCol("subject").setOutputCol("features")
val paramGrid = new ParamGridBuilder().
    addGrid(countVectorizer.setMinTF, Array(1,3,5,7,9)).
    addGrid(logisticRegression.regParam, Array(0.1, 0.01)).
    build()

Thanks for your help.

Update - added more code and changed to countVectorizer.minTF

But I still get errors.

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.sql.SparkSession
import org.apache.log4j._
import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer, StringIndexer,CountVectorizer, CountVectorizerModel,Word2Vec,OneHotEncoder}
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.Row
import org.apache.spark.ml.Pipeline
import org.apache.spark.mllib.evaluation.MulticlassMetrics



Logger.getLogger("org").setLevel(Level.ERROR)
val spark = SparkSession.builder().getOrCreate()
import spark.implicits._  // needed for the $"..." column syntax below (auto-imported in spark-shell)
val data = spark.read.option("header","true").
            option("inferSchema","true").
            option("delimiter","\t").
            format("csv").
            load("datasetId.tsv")
            //withColumn("subject", split($"subject", " "))

val logRegDataAll = data.select(data("labels").as("labelss"),$"subject".as("subjects"))
val logRegData = logRegDataAll.na.drop()


val Array(training,test) = logRegData.randomSplit(Array(0.7,0.3),seed=1)

// Word2Vec
// val word2Vec = new Word2Vec().setInputCol("subject").
//                          setOutputCol("features").
//                          setVectorSize(100)

val tokenizer = new Tokenizer().
                    setInputCol("subjects").
                    setOutputCol("subject")

// TF-IDF
// val hashingTF = new HashingTF().
//              setInputCol("subject").
//              setOutputCol("rawFeatures")
// val idf = new IDF().
//              setInputCol("rawFeatures").
//              setOutputCol("features")

//CountVectorizer / TF
val countVectorizer = new CountVectorizer().
                        setInputCol("subject").
                        setOutputCol("features")

// convert string into numerical values
val labelIndexer = new StringIndexer().
                        setInputCol("labelss").
                        setOutputCol("label")

// convert numerical to one hot encoder
// val labelEncoder = new OneHotEncoder().
//                    setInputCol("labelsss").
//                    setOutputCol("label")

val logisticRegression = new LogisticRegression()

//val pipeline = new Pipeline().setStages(Array(tokenizer,word2Vec,labelIndexer,logisticRegression))
val pipeline = new Pipeline().setStages(Array(tokenizer,countVectorizer,labelIndexer,logisticRegression))
//val pipeline = new Pipeline().setStages(Array(tokenizer,hashingTF,idf,labelIndexer,logisticRegression))


// We use a ParamGridBuilder to construct a grid of parameters to search over.
// With 3 values for countVectorizer.minTF and 2 values for logisticRegression.regParam,
// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.
val paramGrid = new ParamGridBuilder().
    //addGrid(hashingTF.numFeatures, Array(8000,10000,15000)).
    //addGrid(word2Vec.windowSize, Array(1,2,3)).
    addGrid(countVectorizer.minTF, Array(1.0,3.0,5.0)).  // still does not work
    addGrid(logisticRegression.regParam, Array(0.1, 0.01)).
    build()

// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
// This will allow us to jointly choose parameters for all Pipeline stages.
// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
// Note that the evaluator here is a MulticlassClassificationEvaluator and its
// default metric is f1.
val cv = new CrossValidator().
    setEstimator(pipeline).
    setEvaluator(new MulticlassClassificationEvaluator).
    setEstimatorParamMaps(paramGrid).
    setNumFolds(10).  // Use 3+ in practice
    setParallelism(2).  // Evaluate up to 2 parameter settings in parallel
    setSeed(123) // random seed

// Run cross-validation, and choose the best set of parameters.
//val model = pipeline.fit(training)
val model = cv.fit(training)
val result = model.transform(test)

And the dataset, a tab-delimited .tsv file:

labels  subject
CATEGORY_SOCIAL 8 popular Pins for you
CATEGORY_PROMOTIONS Want to plan with Jira and design in UXPin?

If we use countVectorizer.minTF with Array(1.0,3.0,5.0), it gives me an error like this:

found   : org.apache.spark.ml.param.DoubleParam
required: org.apache.spark.ml.param.Param[AnyVal]

1 Answer:

Answer 0 (score: 2)

First, with ParamGridBuilder you need to use the params themselves, not the setters.

Second, your parameter values need to be passed as doubles.

So you would have something like this:

import org.apache.spark.ml.feature.CountVectorizer
import org.apache.spark.ml.tuning.ParamGridBuilder

val countVectorizer = new CountVectorizer().setInputCol("subject").setOutputCol("features")
val paramGrid = new ParamGridBuilder().addGrid(countVectorizer.minTF, Array(1.0,3.0,5.0,7.0,9.0)).build()
// paramGrid: Array[org.apache.spark.ml.param.ParamMap] =
// Array({
//  cntVec_4eab680c176c-minTF: 1.0
// }, {
//  cntVec_4eab680c176c-minTF: 3.0
// }, {
//  cntVec_4eab680c176c-minTF: 5.0
// }, {
//  cntVec_4eab680c176c-minTF: 7.0
// }, {
//  cntVec_4eab680c176c-minTF: 9.0
// })
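
For context: setMinTF is a plain setter method on CountVectorizer, while minTF is the Param object that addGrid expects, which is why the original code did not compile. The values must also be doubles because minTF is a DoubleParam; addGrid is overloaded roughly as sketched below (abridged from ParamGridBuilder), so with Int values neither overload applies and the compiler ends up asking for a Param[AnyVal]:

// Abridged from org.apache.spark.ml.tuning.ParamGridBuilder:
//   def addGrid(param: DoubleParam, values: Array[Double]): ParamGridBuilder
//   def addGrid[T](param: Param[T], values: Iterable[T]): ParamGridBuilder

// minTF is a DoubleParam, so the grid values must be doubles:
// new ParamGridBuilder().addGrid(countVectorizer.minTF, Array(1, 3, 5))      // does not compile
new ParamGridBuilder().addGrid(countVectorizer.minTF, Array(1.0, 3.0, 5.0))  // compiles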

Edit

I could not reproduce your error, but I found other issues. I have annotated them in the code below with the solutions.

// organize imports
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{CountVectorizer, StringIndexer, Tokenizer}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel, ParamGridBuilder}
import org.apache.spark.sql.{DataFrame, SparkSession}

// Create a SparkSession if needed.
val spark = SparkSession.builder().getOrCreate()

// import implicits
import spark.implicits._

// I have created some toy data. 
val data: DataFrame = Seq(
  ("CATEGORY_SOCIAL", "8 popular Pins for you"),
  ("CATEGORY_PROMOTIONS", "Want to plan with Jira and design in UXPin?"),
  ("CATEGORY_PROMOTIONS", "Test our new service today"),
  ("CATEGORY_PROMOTIONS", "deliveries on sunday"),
  ("CATEGORY_SOCIAL", "Twitter - your friends are missing you")
).toDF("labelss", "subjects")

// The tokenizer is ok, even though the column naming can get confusing
val tokenizer: Tokenizer = new Tokenizer().
  setInputCol("subjects").
  setOutputCol("subject")

// Since we are creating a PipelineModel, it's always better 
// to use the column from the previous stage 
val countVectorizer: CountVectorizer = new CountVectorizer().
  setInputCol(tokenizer.getOutputCol).
  setOutputCol("features")

val labelIndexer: StringIndexer = new StringIndexer().
  setInputCol("labelss").
  setOutputCol("labelsss")

// Same comment here 
val logisticRegression: LogisticRegression = new LogisticRegression().setLabelCol(labelIndexer.getOutputCol)

val pipeline: Pipeline = new Pipeline().setStages(Array(tokenizer, countVectorizer, labelIndexer, logisticRegression))

val paramGrid: Array[ParamMap] = new ParamGridBuilder().
  addGrid(countVectorizer.minTF, Array(1.0, 3.0, 5.0)). 
  addGrid(logisticRegression.regParam, Array(0.1, 0.01)).
  build()
// This works well. Result :
//     paramGrid: Array[org.apache.spark.ml.param.ParamMap] =
// Array({
//  cntVec_de795141d282-minTF: 1.0,
//  logreg_fe22d7731a7e-regParam: 0.1
// }, {
//  cntVec_de795141d282-minTF: 3.0,
//  logreg_fe22d7731a7e-regParam: 0.1
// }, {
//  cntVec_de795141d282-minTF: 5.0,
//  logreg_fe22d7731a7e-regParam: 0.1
// }, {
//  cntVec_de795141d282-minTF: 1.0,
//  logreg_fe22d7731a7e-regParam: 0.01
// }, {
//  cntVec_de795141d282-minTF: 3.0,
//  logreg_fe22d7731a7e-regParam: 0.01
// }, {
//  cntVec_de795141d282-minTF: 5.0,
//  logreg_fe22d7731a7e-regParam: 0.01
// })

// Here is the trick, if you don't set your evaluator 
// with the label you need to use explicitly, you'll end up 
// getting an error since your are not using the default 
// label column name value
// Something like : Caused by: java.lang.IllegalArgumentException: Field "label" does not exist.
val evaluator = new MulticlassClassificationEvaluator().setLabelCol(labelIndexer.getOutputCol)
// evaluator: org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator = mcEval_c9d72a485d1d

val cv: CrossValidator = new CrossValidator().
  setEstimator(pipeline).
  setEvaluator(evaluator).
  setEstimatorParamMaps(paramGrid).
  setNumFolds(10). // Use 3+ in practice
  setParallelism(2). // Evaluate up to 2 parameter settings in parallel
  setSeed(123) // random seed
// cv: org.apache.spark.ml.tuning.CrossValidator = cv_2e1c55435a49

val model: CrossValidatorModel = cv.fit(data)
// model: org.apache.spark.ml.tuning.CrossValidatorModel = cv_2e1c55435a49
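
// If you want to inspect the search itself: avgMetrics lines up index-wise
// with paramGrid, and bestModel is the pipeline refit with the winning ParamMap
// (a quick sketch using the standard CrossValidatorModel API):
paramGrid.zip(model.avgMetrics).foreach { case (params, metric) =>
  println(s"$params -> $metric")
}
val bestPipeline = model.bestModel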

val result: DataFrame = model.transform(data)
// result: org.apache.spark.sql.DataFrame = [labelss: string, subjects: string ... 6 more fields]

result.show
// +-------------------+--------------------+--------------------+--------------------+--------+--------------------+--------------------+----------+
// |            labelss|            subjects|             subject|            features|labelsss|       rawPrediction|         probability|prediction|
// +-------------------+--------------------+--------------------+--------------------+--------+--------------------+--------------------+----------+
// |    CATEGORY_SOCIAL|8 popular Pins fo...|[8, popular, pins...|(28,[0,8,16,21,25...|     1.0|[-2.5645425270090...|[0.07145555978623...|       1.0|
// |CATEGORY_PROMOTIONS|Want to plan with...|[want, to, plan, ...|(28,[1,6,9,17,18,...|     0.0|[3.57523120417979...|[0.97275417761670...|       0.0|
// |CATEGORY_PROMOTIONS|Test our new serv...|[test, our, new, ...|(28,[3,4,10,12,20...|     0.0|[3.15934297459226...|[0.95927528667918...|       0.0|
// |CATEGORY_PROMOTIONS|deliveries on sunday|[deliveries, on, ...|(28,[5,22,26],[1....|     0.0|[2.81641463947790...|[0.94355642175747...|       0.0|
// |    CATEGORY_SOCIAL|Twitter - your fr...|[twitter, -, your...|(28,[0,2,7,11,13,...|     1.0|[-2.8838332277996...|[0.05295855512212...|       1.0|
// +-------------------+--------------------+--------------------+--------------------+--------+--------------------+--------------------+----------+

Note: For practical reasons I did not split my data here; there was not enough of it to split.
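
With more data, you would fit on a training split and score the held-out rows, as in the original code. A minimal sketch reusing the names above (MulticlassClassificationEvaluator's default metric is f1):

val Array(training, test) = data.randomSplit(Array(0.7, 0.3), seed = 1)
val cvModel = cv.fit(training)                               // cross-validate on the training split only
val heldOutF1 = evaluator.evaluate(cvModel.transform(test))  // score the untouched test rows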