Is there a way to access methods from the individual stages of a PySpark PipelineModel?

Time: 2016-07-29 17:42:06

Tags: python apache-spark pyspark apache-spark-mllib apache-spark-ml

I have built a pipeline for LDA in Spark 2.0 (via the PySpark API); fitting it yields a PipelineModel:

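from pyspark.ml import Pipeline
from pyspark.ml.clustering import LDA
from pyspark.ml.feature import CountVectorizer, RegexTokenizer


def create_lda_pipeline(minTokenLength=1, minDF=1, minTF=1, numTopics=10, seed=42, pattern=r'[\W]+'):
    """
    Create a pipeline for running an LDA model on a corpus. This function does not need data and will not actually do
    any fitting until invoked by the caller.
    Args:
        minTokenLength: minimum length of a token kept by the tokenizer
        minDF: minimum number of documents a word must appear in across the corpus
        minTF: minimum number of times a word must appear in a document
        numTopics: number of topics (k) for LDA
        seed: random seed passed to LDA
        pattern: regular expression used to split words

    Returns:
        pipeline: an unfitted pyspark.ml.Pipeline (calling fit() on it returns a PipelineModel)
    """
    # Split raw text into tokens, dropping tokens shorter than minTokenLength.
    reTokenizer = RegexTokenizer(inputCol="text", outputCol="tokens", pattern=pattern, minTokenLength=minTokenLength)
    # Turn token lists into term-count vectors.
    cntVec = CountVectorizer(inputCol=reTokenizer.getOutputCol(), outputCol="vectors", minDF=minDF, minTF=minTF)
    # LDA reads the count vectors produced by the previous stage.
    lda = LDA(k=numTopics, seed=seed, optimizer="em", featuresCol=cntVec.getOutputCol())
    pipeline = Pipeline(stages=[reTokenizer, cntVec, lda])
    return pipeline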

I want to compute the perplexity of a dataset with the trained model using the LDAModel.logPerplexity() method, so I tried running the following:

from pprint import pprint

training = get_20_newsgroups_data(test_or_train='train')
pipeline = create_lda_pipeline(numTopics=20, minDF=3, minTokenLength=5)
model = pipeline.fit(training)  # train model on training data
testing = get_20_newsgroups_data(test_or_train='test')
perplexity = model.logPerplexity(testing)  # this line raises the AttributeError
pprint(perplexity)

This just produces the following AttributeError:

'PipelineModel' object has no attribute 'logPerplexity'

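I understand why this error occurs: the logPerplexity method belongs to LDAModel, not PipelineModel. What I would like to know is whether there is a way to reach that method through the corresponding stage.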

2 answers:

Answer 0 (score: 8)

All transformers in the fitted pipeline are stored in its stages attribute. Extract stages, take the last one, and you are ready to go:

model.stages[-1].logPerplexity(testing)

Answer 1 (score: 0)

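I ran into a case where pipeline.stages did not work: pipeline.stages was treated as a Param rather than a list. In that case, use

pipeline.getStages()

and you will get the list of stages, just as pipeline.stages gives you in most cases.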