PySpark在嵌套数组中反转StringIndexer

时间:2017-08-20 22:26:47

标签: python apache-spark pyspark apache-spark-sql apache-spark-ml

我正在使用PySpark使用ALS进行协同过滤。我的原始用户和项目ID是字符串,因此我使用Into<B>将它们转换为数字索引(PySpark的ALS模型要求我们这样做)。

在我安装模型之后,我可以为每个用户提供前三条建议:

fn a<A, B: Copy + Into<A>>(a: A, b: B) {}

StringIndexer数据框如下所示:

recs = (
    model
    .recommendForAllUsers(3)
)

我想用这个数据框创建一个巨大的JSOM转储,我可以这样:

recs

这些jsons的样本是:

+-----------+--------------------+
|userIdIndex|     recommendations|
+-----------+--------------------+
|       1580|[[10096,3.6725707...|
|       4900|[[10096,3.0137873...|
|       5300|[[10096,2.7274625...|
|       6620|[[10096,2.4493625...|
|       7240|[[10096,2.4928937...|
+-----------+--------------------+
only showing top 5 rows

root
 |-- userIdIndex: integer (nullable = false)
 |-- recommendations: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- productIdIndex: integer (nullable = true)
 |    |    |-- rating: float (nullable = true)

( recs .toJSON() .saveAsTextFile("name_i_must_hide.recs") ) { "userIdIndex": 1580, "recommendations": [ { "productIdIndex": 10096, "rating": 3.6725707 }, { "productIdIndex": 10141, "rating": 3.61542 }, { "productIdIndex": 11591, "rating": 3.536216 } ] } 密钥归因于userIdIndex转换。

如何获取这些列的原始值?我怀疑我必须使用productIdIndex变换器,但我无法弄清楚数据是如何嵌套在StringIndexer数据帧内的数组中的。

我尝试使用IndexToString评估程序(recs),但看起来这个评估程序不支持这些索引器。

干杯!

1 个答案:

答案 0 :(得分:1)

在这两种情况下,您都需要访问标签列表。可以使用StringIndexerModel

访问此内容
user_indexer_model = ...  # type: StringIndexerModel
user_labels = user_indexer_model.labels

product_indexer_model = ...  # type: StringIndexerModel
product_labels = product_indexer_model.labels

或列元数据。

对于userIdIndex,您只需应用IndexToString

from pyspark.ml.feature import IndexToString

user_id_to_label = IndexToString(
    inputCol="userIdIndex", outputCol="userId", labels=user_labels)
user_id_to_label.transform(recs)

对于建议,您需要udf或这样的表达式:

from pyspark.sql.functions import array, col, lit, struct

n = 3  # Same as numItems

product_labels_ = array(*[lit(x) for x in product_labels])
recommendations = array(*[struct(
    product_labels_[col("recommendations")[i]["productIdIndex"]].alias("productId"),
    col("recommendations")[i]["rating"].alias("rating")
) for i in range(n)])

recs.withColumn("recommendations", recommendations)