如何创建Spark DataFrame,其中一列具有数组值?

时间:2017-10-04 13:39:05

标签: pyspark pyspark-sql

我有以下代码,为某些用户创建建议:

als = ALS(rank=4, maxIter=3, regParam=0.01, implicitPrefs=True, alpha=40,
              userCol="userId", itemCol="itemId", ratingCol="rating")
model = als.fit(train_df)

现在我想把我的测试DataFrame扩展为每个用户的top-k预测,作为一个数组。例如,如果我的测试DataFrame是:

userId    ProductId
1         2
1         3
2         7
3         4

它应该变成:

userId    ProductId    Predictions
1         2            [2, 4, 6, 8 , 3, 9, 10, 3, 50, 14]
1         3            ...
2         7            ...
3         4            ...

我正在使用这样的udf

predictions_udf = udf(user_prediction, ArrayType(DoubleType()))
predictions_df = test_df.select('*').withColumn('Predictions', predictions_udf())

我应该如何实现此user_prediction方法。我已尝试过所有但无法传递模型或DataFrame作为方法参数,并且无法将建议作为ArrayType()返回。

0 个答案:

没有答案