我的目标是构建一个多字符分类器。
我已经构建了一个用于特征提取的管道,它包括一个StringIndexer转换器,用于将每个类名映射到一个标签,该标签将用于分类器训练步骤。
管道安装在训练集上。
测试集必须由拟合的管道处理,以便提取相同的特征向量。
知道我的测试集文件具有与训练集相同的结构。这里可能的情况是在测试集中面对一个看不见的类名,在这种情况下,StringIndexer将无法找到标签,并且将引发异常。
这种情况有解决方案吗?或者我们如何避免这种情况发生?
答案 0 :(得分:14)
在Spark 1.6中可以解决这个问题。
这里是jira:https://issues.apache.org/jira/browse/SPARK-8764
以下是一个例子:
val categoryIndexerModel = new StringIndexer()
.setInputCol("category")
.setOutputCol("indexedCategory")
.setHandleInvalid("skip") // new method. values are "error" or "skip"
我开始使用它,但最终回到KrisP的第二个要点,将这个特定的Estimator安装到完整的数据集中。
转换IndexToString后,您将在管道中稍后需要这个。
以下是修改后的示例:
val categoryIndexerModel = new StringIndexer()
.setInputCol("category")
.setOutputCol("indexedCategory")
.fit(itemsDF) // Fit the Estimator and create a Model (Transformer)
... do some kind of classification ...
val categoryReverseIndexer = new IndexToString()
.setInputCol(classifier.getPredictionCol)
.setOutputCol("predictedCategory")
.setLabels(categoryIndexerModel.labels) // Use the labels from the Model
答案 1 :(得分:10)
没有好办法,我很害怕。任
StringIndexer
StringIndexer
符合列车和测试数据框架的联合,因此您可以放心所有标签都在那里以下是执行上述操作的示例代码:
// get training labels from original train dataframe
val trainlabels = traindf.select(colname).distinct.map(_.getString(0)).collect //Array[String]
// or get labels from a trained StringIndexer model
val trainlabels = simodel.labels
// define an UDF on your dataframe that will be used for filtering
val filterudf = udf { label:String => trainlabels.contains(label)}
// filter out the bad examples
val filteredTestdf = testdf.filter( filterudf(testdf(colname)))
// transform unknown value to some value, say "a"
val mapudf = udf { label:String => if (trainlabels.contains(label)) label else "a"}
// add a new column to testdf:
val transformedTestdf = testdf.withColumn( "newcol", mapudf(testdf(colname)))
答案 2 :(得分:1)
在我的情况下,我在一个大型数据集上运行spark ALS并且数据在所有分区都不可用,因此我必须适当地缓存()数据并且它像魅力一样工作
答案 3 :(得分:1)
对我而言,通过设置参数(https://issues.apache.org/jira/browse/SPARK-8764)完全忽略行并不是解决问题的可行方法。
我最终创建了自己的CustomStringIndexer转换器,它将为训练时未遇到的所有新字符串分配新值。您也可以通过更改spark功能代码的相关部分来执行此操作(只需删除if条件,显式检查此项并使其返回数组的长度)并重新编译jar。
不是一个简单的修复方法,但肯定是一个修复方法。
我记得在JIRA中也看到了一个错误:https://issues.apache.org/jira/browse/SPARK-17498
它设置为随Spark 2.2发布。我必须等待:S