Spark,ML,StringIndexer:处理看不见的标签

时间:2016-01-08 16:20:05

标签: apache-spark apache-spark-ml

我的目标是构建一个多字符分类器。

我已经构建了一个用于特征提取的管道,它包括一个StringIndexer转换器,用于将每个类名映射到一个标签,该标签将用于分类器训练步骤。

管道安装在训练集上。

测试集必须由拟合的管道处理,以便提取相同的特征向量。

知道我的测试集文件具有与训练集相同的结构。这里可能的情况是在测试集中面对一个看不见的类名,在这种情况下,StringIndexer将无法找到标签,并且将引发异常。

这种情况有解决方案吗?或者我们如何避免这种情况发生?

4 个答案:

答案 0 :(得分:14)

在Spark 1.6中可以解决这个问题。

这里是jira:https://issues.apache.org/jira/browse/SPARK-8764

以下是一个例子:

val categoryIndexerModel = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("indexedCategory")
  .setHandleInvalid("skip") // new method.  values are "error" or "skip"

我开始使用它,但最终回到KrisP的第二个要点,将这个特定的Estimator安装到完整的数据集中。

转换IndexToString后,您将在管道中稍后需要这个。

以下是修改后的示例:

val categoryIndexerModel = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("indexedCategory")
  .fit(itemsDF) // Fit the Estimator and create a Model (Transformer)

... do some kind of classification ...

val categoryReverseIndexer = new IndexToString()
  .setInputCol(classifier.getPredictionCol)
  .setOutputCol("predictedCategory")
  .setLabels(categoryIndexerModel.labels) // Use the labels from the Model

答案 1 :(得分:10)

没有好办法,我很害怕。任

  • 在应用StringIndexer
  • 之前,使用未知标签过滤掉测试示例
  • 或使StringIndexer符合列车和测试数据框架的联合,因此您可以放心所有标签都在那里
  • 或将具有未知标签的测试示例案例转换为已知标签

以下是执行上述操作的示例代码:

// get training labels from original train dataframe
val trainlabels = traindf.select(colname).distinct.map(_.getString(0)).collect  //Array[String]
// or get labels from a trained StringIndexer model
val trainlabels = simodel.labels 

// define an UDF on your dataframe that will be used for filtering
val filterudf = udf { label:String => trainlabels.contains(label)}

// filter out the bad examples 
val filteredTestdf = testdf.filter( filterudf(testdf(colname)))

// transform unknown value to some value, say "a"
val mapudf = udf { label:String => if (trainlabels.contains(label)) label else "a"}

// add a new column to testdf: 
val transformedTestdf = testdf.withColumn( "newcol", mapudf(testdf(colname)))

答案 2 :(得分:1)

在我的情况下,我在一个大型数据集上运行spark ALS并且数据在所有分区都不可用,因此我必须适当地缓存()数据并且它像魅力一样工作

答案 3 :(得分:1)

对我而言,通过设置参数(https://issues.apache.org/jira/browse/SPARK-8764)完全忽略行并不是解决问题的可行方法。

我最终创建了自己的CustomStringIndexer转换器,它将为训练时未遇到的所有新字符串分配新值。您也可以通过更改spark功能代码的相关部分来执行此操作(只需删除if条件,显式检查此项并使其返回数组的长度)并重新编译jar。

不是一个简单的修复方法,但肯定是一个修复方法。

我记得在JIRA中也看到了一个错误:https://issues.apache.org/jira/browse/SPARK-17498

它设置为随Spark 2.2发布。我必须等待:S