从Spark(pyspark)中的管道中的StringIndexer阶段获取标签

时间:2017-08-25 15:42:31

标签: python apache-spark pyspark

我正在使用Sparkpyspark我设置了pipeline一堆StringIndexer个对象,用于将字符串列编码为列指数:

indexers = [StringIndexer(inputCol=column, outputCol=column + '_index').setHandleInvalid('skip')
            for column in list(set(data_frame.columns) - ignore_columns)]
pipeline = Pipeline(stages=indexers)
new_data_frame = pipeline.fit(data_frame).transform(data_frame)

问题是,我需要在每个StringIndexer对象安装完成后获取标签列表。对于单个列和没有管道的单个StringIndexer,这是一项简单的任务。在labels

上拟合索引器后,我只能访问DataFrame属性
indexer = StringIndexer(inputCol="name", outputCol="name_index")
indexer_fitted = indexer.fit(data_frame)
labels = indexer_fitted.labels
new_data_frame = indexer_fitted.transform(data_frame)

然而,当我使用管道时,这似乎不可能,或者至少我不知道如何做到这一点。

所以我想我的问题归结为: 有没有办法访问每个列的索引过程中使用的标签?

或者我是否必须在此用例中抛弃管道,例如循环遍历StringIndexer对象列表并手动执行? (我确信这是可能的。但是使用管道会更好)

1 个答案:

答案 0 :(得分:5)

示例数据和Pipeline

from pyspark.ml.feature import StringIndexer, StringIndexerModel

df = spark.createDataFrame([("a", "foo"), ("b", "bar")], ("x1", "x2"))

pipeline = Pipeline(stages=[
    StringIndexer(inputCol=c, outputCol='{}_index'.format(c))
    for c in df.columns
])

model = pipeline.fit(df)

摘自stages

# Accessing _java_obj shouldn't be necessary in Spark 2.3+
{x._java_obj.getOutputCol(): x.labels 
for x in model.stages if isinstance(x, StringIndexerModel)}
{'x1_index': ['a', 'b'], 'x2_index': ['foo', 'bar']}

来自已转换的DataFrame的元数据:

indexed = model.transform(df)

{c.name: c.metadata["ml_attr"]["vals"]
for c in indexed.schema.fields if c.name.endswith("_index")}
{'x1_index': ['a', 'b'], 'x2_index': ['foo', 'bar']}