适合管道和处理数据

时间:2017-02-19 16:09:26

标签: scala apache-spark pipeline

我有一个包含文字的文件。我想要做的是使用管道来标记文本,删除停用词并产生2克。

到目前为止我做了什么:

第1步:阅读文件

val data = sparkSession.read.text("data.txt").toDF("text")

第2步:构建管道

val pipe1 = new Tokenizer().setInputCol("text").setOutputCol("words")
val pipe2 = new StopWordsRemover().setInputCol("words").setOutputCol("filtered")
val pipe3 = new NGram().setN(2).setInputCol("filtered").setOutputCol("ngrams")

val pipeline = new Pipeline().setStages(Array(pipe1, pipe2, pipe3))
val model = pipeline.fit(data)

我知道pipeline.fit(data)会产生PipelineModel,但我不知道如何使用PipelineModel

非常感谢任何帮助。

1 个答案:

答案 0 :(得分:2)

当您运行val model = pipeline.fit(data)代码时,所有Estimator阶段(即:机器学习任务,如分类,回归,群集等)都适合数据,Transformer阶段是创建。您只有Transformer个阶段,因为您正在此管道中创建要素。

要执行您的模型,现在只包含Transformer个阶段,您需要运行val results = model.transform(data)。这将针对您的数据框执行每个Transformer阶段。因此,在model.transform(data)进程结束时,您将拥有一个由原始行,Tokenizer输出,StopWordsRemover输出以及最终NGram结果组成的数据帧。

完成功能创建后,发现前5个ngrams可以通过SparkSQL查询执行。首先展开ngram列,然后按ngrams计数group,按计数列以降序排序,然后执行show(5)。或者,您可以使用"LIMIT 5方法代替show(5)

另外,您应该将对象名称更改为不是标准类名称的名称。否则你将会出现一个不确定的范围错误。

<强> CODE:

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.feature.Tokenizer
import org.apache.spark.sql.SparkSession._
import org.apache.spark.sql.functions._
import org.apache.spark.ml.feature.NGram
import org.apache.spark.ml.feature.StopWordsRemover
import org.apache.spark.ml.{Pipeline, PipelineModel}

object NGramPipeline {
    def main() {
        val sparkSession = SparkSession.builder.appName("NGram Pipeline").getOrCreate()

        val sc = sparkSession.sparkContext

        val data = sparkSession.read.text("quangle.txt").toDF("text")

        val pipe1 = new Tokenizer().setInputCol("text").setOutputCol("words")
        val pipe2 = new StopWordsRemover().setInputCol("words").setOutputCol("filtered")
        val pipe3 = new NGram().setN(2).setInputCol("filtered").setOutputCol("ngrams")

        val pipeline = new Pipeline().setStages(Array(pipe1, pipe2, pipe3))
        val model = pipeline.fit(data)

        val results = model.transform(data)

        val explodedNGrams = results.withColumn("explNGrams", explode($"ngrams"))
        explodedNGrams.groupBy("explNGrams").agg(count("*") as "ngramCount").orderBy(desc("ngramCount")).show(10,false)

    }
}
NGramPipeline.main()



输出:

+-----------------+----------+
|explNGrams       |ngramCount|
+-----------------+----------+
|quangle wangle   |9         |
|wangle quee.     |4         |
|'mr. quangle     |3         |
|said, --         |2         |
|wangle said      |2         |
|crumpetty tree   |2         |
|crumpetty tree,  |2         |
|quangle wangle,  |2         |
|crumpetty tree,--|2         |
|blue babboon,    |2         |
+-----------------+----------+
only showing top 10 rows

请注意,存在导致行重复的语法(逗号,短划线等)。在执行ngrams时,过滤我们的语法通常是一个好主意。您通常可以使用正则表达式执行此操作。