Spark - How to use QuantileDiscretizer with RandomForestClassifier

Date: 2018-03-23 21:03:21

Tags: scala apache-spark apache-spark-ml

Is it possible to use QuantileDiscretizer, keeping NaN values, together with a RandomForestClassifier?

The error I get is this:

18/03/23 17:38:15 ERROR Executor: Exception in task 3.0 in stage 133.0 (TID 381)
java.lang.IllegalArgumentException: DecisionTree given invalid data: Feature 1 is categorical with values in {0,...,1, but a data point gives it value 2.0.
  Bad data point: (1.0,[1.0,2.0])

Example

The idea here is to create a numeric column and discretize it using quantiles, keeping the invalid numbers (NaN) in a special bucket.

import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler,
  QuantileDiscretizer}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{RandomForestClassifier}

val tseq = Seq((0, "a", 1.0), (1, "b", 0.0), (2, "c", 2.0),
               (3, "a", 1.0), (4, "a", 3.0), (5, "c", Double.NaN))
val tdf = SparkInit.ss.createDataFrame(tseq).toDF("id", "category", "class")
val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
// handleInvalid = "keep" puts NaN values into an extra bucket of their own
val discr = new QuantileDiscretizer()
  .setInputCol("class")
  .setOutputCol("quant")
  .setNumBuckets(2)
  .setHandleInvalid("keep")
val assembler = new VectorAssembler()
  .setInputCols(Array("categoryIndex", "quant"))
  .setOutputCol("features")
val rf = new RandomForestClassifier()
  .setLabelCol("categoryIndex")
  .setFeaturesCol("features")
  .setNumTrees(3)
new Pipeline()
  .setStages(Array(indexer, discr, assembler, rf))
  .fit(tdf)
  .transform(tdf)
  .show()

Without trying to fit the random forest, I get a DataFrame like this:

+---+--------+-----+-------------+-----+---------+
| id|category|class|categoryIndex|quant| features|
+---+--------+-----+-------------+-----+---------+
|  0|       a|  1.0|          0.0|  1.0|[0.0,1.0]|
|  1|       b|  0.0|          2.0|  0.0|[2.0,0.0]|
|  2|       c|  2.0|          1.0|  1.0|[1.0,1.0]|
|  3|       a|  1.0|          0.0|  1.0|[0.0,1.0]|
|  4|       a|  3.0|          0.0|  1.0|[0.0,1.0]|
|  5|       c|  NaN|          1.0|  2.0|[1.0,2.0]|
+---+--------+-----+-------------+-----+---------+

If I try to fit the model, I get the error:

18/03/23 17:54:12 WARN DecisionTreeMetadata: DecisionTree reducing maxBins from 32 to 6 (= number of training instances)
18/03/23 17:54:12 WARN BlockManager: Putting block rdd_490_3 failed due to an exception
18/03/23 17:54:12 WARN BlockManager: Block rdd_490_3 could not be removed as it was not found on disk or in memory
18/03/23 17:54:12 ERROR Executor: Exception in task 3.0 in stage 143.0 (TID 414)
java.lang.IllegalArgumentException: DecisionTree given invalid data: Feature 1 is categorical with values in {0,...,1, but a data point gives it value 2.0.
  Bad data point: (1.0,[1.0,2.0])
    at org.apache.spark.ml.tree.impl.TreePoint$.findBin(TreePoint.scala:124)
    at org.apache.spark.ml.tree.impl.TreePoint$.org$apache$spark$ml$tree$impl$TreePoint$$labeledPointToTreePoint(TreePoint.scala:93)
    at org.apache.spark.ml.tree.impl.TreePoint$$anonfun$convertToTreeRDD$2.apply(TreePoint.scala:73)
    at org.apache.spark.ml.tree.impl.TreePoint$$anonfun$convertToTreeRDD$2.apply(TreePoint.scala:72)
    at scala.collection.Iterator$$anon$11.next(Iterator.scala:410)
    at scala.collection.Iterator$$anon$11.next(Iterator.scala:410)
    at org.apache.spark.storage.memory.MemoryStore.putIteratorAsValues(MemoryStore.scala:216)

Does QuantileDiscretizer insert some kind of metadata about the special extra bucket? Strangely enough, I was able to build a model before using a column with the same values, just without forcing any discretization.
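
A quick way to check is to inspect the schema entry of the transformed column. A minimal sketch, fitting the pipeline without the rf stage so training doesn't fail (out is just an illustrative name):

// Run only the feature stages and print the metadata attached to "quant".
val out = new Pipeline()
  .setStages(Array(indexer, discr, assembler))
  .fit(tdf)
  .transform(tdf)
println(out.schema("quant").metadata.json)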

Update

Yes, the column does have metadata attached to it, and it looks like this:

org.apache.spark.sql.types.Metadata = {"ml_attr":
   {"ord":true,
    "vals":["-Infinity, 5.0","5.0, 10.0","10.0, Infinity"],
    "type":"nominal"}
}

The question now might be: how do I correctly set the metadata to include values like Double.NaN?
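
For what it's worth, a heavily hedged sketch of rebuilding the attribute metadata by hand with an extra bucket, using org.apache.spark.ml.attribute.NominalAttribute; I have not verified that tree training then accepts the NaN bucket, and the bucket labels simply mirror the metadata shown above (df stands for the transformed DataFrame):

import org.apache.spark.ml.attribute.NominalAttribute

// Rebuild the nominal attribute with one extra value for the NaN bucket.
val withNaNBucket = NominalAttribute.defaultAttr
  .withName("quant")
  .withValues(Array("-Infinity, 5.0", "5.0, 10.0", "10.0, Infinity", "NaN"))
  .toMetadata()
// Attach the patched metadata to the column.
val patched = df.withColumn("quant", df("quant").as("quant", withNaNBucket))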

1 Answer:

Answer 0 (score: 0)

The workaround I used was simply to remove the associated metadata from the discretized column and let the decision tree implementation decide what to do with the data. I think the column then effectively becomes a numeric column (e.g. [0, 1, 2, 2, 1]); if that produces too many distinct values, the column can be discretized again (look up the maxBins parameter, sketched below).
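
For reference, a minimal sketch of raising the bin limit on the estimator itself; the setMaxBins value here is illustrative (Spark's default is 32):

// Hypothetical tweak, not from the original question: raise maxBins so the
// tree can bin the now-continuous "quant" feature into more candidate splits.
val rf = new RandomForestClassifier()
  .setLabelCol("categoryIndex")
  .setFeaturesCol("features")
  .setNumTrees(3)
  .setMaxBins(64)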

In my case, the simplest way to remove the metadata was to fill the DataFrame after applying the QuantileDiscretizer:

// Nothing is actually filled in my case, since there were no missing
// values before this operation; na.fill just returns a new DataFrame
// whose "quant" column no longer carries the ml_attr metadata.
df.na.fill(Double.NaN, Array("quant"))

I'm almost sure you could also remove the metadata manually by accessing the column object directly.

Update

We can change a column's metadata by creating an alias (reference):

import org.apache.spark.sql.types.Metadata

val metadata: Metadata = ...
df.select($"colA".as("colB", metadata))

This answer describes a way to get a column's metadata by fetching the corresponding StructField from the DataFrame's schema.
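
Putting the two references together, a minimal, untested sketch (assuming df holds the transformed DataFrame from the question):

import org.apache.spark.sql.types.Metadata

// Read the ml_attr metadata from the column's StructField in the schema.
val quantMeta: Metadata = df.schema("quant").metadata

// Strip the metadata by re-aliasing the column with Metadata.empty, so the
// tree implementation treats "quant" as a plain numeric feature.
val noMeta = df.withColumn("quant", df("quant").as("quant", Metadata.empty))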