ApacheSparkML StringIndexer吃掉我的列

时间:2017-04-27 15:43:27

标签: apache-spark apache-spark-sql apache-spark-ml

将StringIndexer应用于包含以下列的df_notnull(DataFrame对象)时:

scala> df_notnull.printSchema
root
 |-- L0_S22_F545: string (nullable = true)
 |-- L0_S0_F0: double (nullable = true)
 |-- L0_S0_F2: double (nullable = true)
 |-- L0_S0_F4: double (nullable = true)

只剩下那些:

scala> indexed.printSchema
root
 |-- L0_S22_F545: string (nullable = true)
 |-- L0_S22_F545Index: double (nullable = true)

这是我的代码:

:paste
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}

val indexer = new StringIndexer()
  .setInputCol("L0_S22_F545")
  .setOutputCol("L0_S22_F545Index")

val indexed = indexer.fit(df_notnull).transform(df_notnull)
indexed.printSchema

我想保留所有列,只添加一些新列。我做错了什么?

1 个答案:

答案 0 :(得分:0)

找到解决方案here。实际上变压器不应该单独使用,而应该与管道一起使用 - 然后保留列:

import org.apache.spark.ml.Pipeline
val transformers = Array(
    indexer,
    encoder
)

var pipeline = new Pipeline().setStages(transformers).fit(df_notnull)

var transformed = pipeline.transform(df_notnull)

结果如下:

scala> transformed.show
+-----------+--------+--------+--------+----------------+--------------+        
|L0_S22_F545|L0_S0_F0|L0_S0_F2|L0_S0_F4|L0_S22_F545Index|L0_S22_F545Vec|
+-----------+--------+--------+--------+----------------+--------------+
|         NA|    0.03|  -0.034|  -0.197|             0.0|(13,[0],[1.0])|
|         NA|     0.0|     0.0|     0.0|             0.0|(13,[0],[1.0])|
|         NA|   0.088|   0.086|   0.003|             0.0|(13,[0],[1.0])|
|         NA|  -0.036|  -0.064|   0.294|             0.0|(13,[0],[1.0])|
|         NA|  -0.055|  -0.086|   0.294|             0.0|(13,[0],[1.0])|
|         NA|   0.003|   0.019|   0.294|             0.0|(13,[0],[1.0])|
|         NA|     0.0|     0.0|     0.0|             0.0|(13,[0],[1.0])|
|         NA|     0.0|     0.0|     0.0|             0.0|(13,[0],[1.0])|
|         NA|  -0.016|  -0.041|  -0.179|             0.0|(13,[0],[1.0])|
|         NA|     0.0|     0.0|     0.0|             0.0|(13,[0],[1.0])|
|         NA|   0.016|   0.093|  -0.015|             0.0|(13,[0],[1.0])|
|         NA|  -0.062|  -0.153|  -0.197|             0.0|(13,[0],[1.0])|
|         NA|  -0.075|  -0.093|   0.367|             0.0|(13,[0],[1.0])|
|         NA|  -0.003|  -0.093|  -0.161|             0.0|(13,[0],[1.0])|
|         NA|  -0.016|  -0.138|  -0.197|             0.0|(13,[0],[1.0])|
|         NA|   0.252|    0.25|   0.003|             0.0|(13,[0],[1.0])|
|         NA|     0.0|     0.0|     0.0|             0.0|(13,[0],[1.0])|
|         NA|  -0.016|  -0.041|   0.003|             0.0|(13,[0],[1.0])|
|         NA|     0.0|     0.0|     0.0|             0.0|(13,[0],[1.0])|
|         NA|   0.088|   0.033|    0.33|             0.0|(13,[0],[1.0])|
+-----------+--------+--------+--------+----------------+--------------+
only showing top 20 rows