使用VectorAssembler

时间:2017-01-11 14:40:45

标签: apache-spark pipeline apache-spark-ml

使用sparks矢量汇编程序需要预先定义要组装的列。

但是,如果在前面的步骤将修改数据帧的列的管道中使用向量汇编程序,我如何指定列而无需手动对所有值进行硬编码?

df.columns 包含正确的值时,当构造函数被调用vector-assembler时,我目前看不到另一种方法来处理它或拆分管道 - 这很糟糕因为CrossValidator将不再正常工作。

val vectorAssembler = new VectorAssembler()
    .setInputCols(df.columns
      .filter(!_.contains("target"))
      .filter(!_.contains("idNumber")))
    .setOutputCol("features")

修改

的初始df
---+------+---+-
|foo|   id|baz|
+---+------+---+
|  0| 1    |  A|
|  1|2     |  A|
|  0| 3    |  null|
|  1| 4    |  C|
+---+------+---+

将转换如下。您可以看到,对于最常见的原始列以及某些特征(例如,如此处所述isA如果baz为A,则为1,否则为0,如果为null则为N

+---+------+---+-------+
|foo|id    |baz| isA    |
+---+------+---+-------+
|  0| 1    |  A| 1      |
|  1|2     |  A|1       |
|  0| 3    |   A|    n  |
|  1| 4    |  C|    0   |
+---+------+---+-------+

稍后在管道中,使用stringIndexer使数据适合ML / vectorAssembler。

isA不存在于原始df中,但不会出现在"只有"输出列除了foo和id列之外,该帧中的所有列都应该由向量汇编器转换。

我希望现在更清楚了。

2 个答案:

答案 0 :(得分:4)

如果我理解你的问题,那么答案就会非常简单直接,你只需要使用上一个变换器中的.getOutputCol

示例(来自官方文档):

// Prepare training documents from a list of (id, text, label) tuples.
val training = spark.createDataFrame(Seq(
  (0L, "a b c d e spark", 1.0),
  (1L, "b d", 0.0),
  (2L, "spark f g h", 1.0),
  (3L, "hadoop mapreduce", 0.0)
)).toDF("id", "text", "label")

// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
val tokenizer = new Tokenizer()
  .setInputCol("text")
  .setOutputCol("words")
val hashingTF = new HashingTF()
  .setNumFeatures(1000)
  .setInputCol(tokenizer.getOutputCol) // <==== Using the tokenizer output column
  .setOutputCol("features")
val lr = new LogisticRegression()
  .setMaxIter(10)
  .setRegParam(0.001)
val pipeline = new Pipeline()
  .setStages(Array(tokenizer, hashingTF, lr))

现在让我们将其应用于考虑另一个假设列alpha的VectorAssembler:

val assembler = new VectorAssembler()
  .setInputCols(Array("alpha", tokenizer.getOutputCol)
  .setOutputCol("features")

答案 1 :(得分:1)

我创建了一个自定义矢量汇编程序(原始版本的1:1副本),然后将其更改为包括除了一些传递以排除的列之外的所有列。

修改

让它更清晰

def setInputColsExcept(value: Array[String]): this.type = set(inputCols, value)

指定应排除哪些列。然后

val remainingColumns = dataset.columns.filter(!$(inputCols).contains(_))
转换方法中的

是对所需列进行过滤。