来自ML Tuning
https://spark.apache.org/docs/latest/ml-tuning.html的Spark文档的以下片段显然为numFeatures
设置了Hashing TermFrequency
,为{{regParam
设置了LogisticRegression
(正则化) 1}}模型:
HashingTF
和LogisticRegression
:
val hashingTF = new HashingTF()
.setInputCol(tokenizer.getOutputCol)
.setOutputCol("features")
val lr = new LogisticRegression()
.setMaxIter(10)
CrossValidator
的 ParamGridBuilder :
// We use a ParamGridBuilder to construct a grid of parameters to search over.
// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.
val paramGrid = new ParamGridBuilder()
.addGrid(hashingTF.numFeatures, Array(10, 100, 1000))
.addGrid(lr.regParam, Array(0.1, 0.01))
.build()
CrossValidator
“如何知道”如何将网格值应用于各个实体?我想看看它是否是通过反思但是不清楚。
`CrossValidator可能设置的方法是:
HashingTF :
/** @group setParam */
@Since("1.2.0")
def setNumFeatures(value: Int): this.type = set(numFeatures, value)
逻辑回归:
class LogisticRegressionModel @Since("1.3.0") (
..
@Since("1.3.0") val numFeatures: Int,
这是CrossValidator
上的调用:
val cv = new CrossValidator()
.setEstimator(pipeline)
.setEvaluator(new BinaryClassificationEvaluator)
.setEstimatorParamMaps(paramGrid)
.setNumFolds(2) // Use 3+ in practice
我无法确定setEstimatorParamMaps
如何正确设置HashingTF
和LogisticRegression
值。 (注意 工作!)
这个问题的原因是我想添加一个新的Evaluator
,我不确定如何将其与CrossValidator
功能相匹配。
一个具体的例子:对于LDAModel
:我们有调整参数k
,vocabSize
和docConcentration
:如何为这些参数设置ParamGrid
?
答案 0 :(得分:1)
一个具体的例子:对于LDAModel:我们有调整参数k,vocabSize和docConcentration:如何为那些设置ParamGrid?
addGrid
会获得Param
和Array
个兼容值。通常,它设置在Estimator
(LDA
)而不是Transformer (
LDAModel`)上。
要设置k
,docConcentration
只需按以下类型:
val lda = new LDA()
val paramGrid = new ParamGridBuilder()
.addGrid(lda.k, Array(3, 5, 7))
.addGrid(lda.docConcentration, Array(Array(0.1, 0.4, 0.5)))
.build()
我们有调整参数(...)vocabSize
词汇大小由输入向量定义。它是无法调整的。
CrossValidator"如何知道"如何将网格值应用于各自的实体?
模型提供fit
方法,该方法需要dataset
和ParamMap
。 For example LDA
:
def fit(dataset: Dataset[_], paramMap: ParamMap): LDAModel
使用提供的参数图将单个模型拟合到输入数据。
此变种is used位于CrossValidator
。