如何将Spark数据帧列从Array [Int]转换为linalg.Vector?

时间:2017-10-17 18:57:13

标签: scala apache-spark spark-dataframe apache-spark-ml

我有一个数据帧df,如下所示:

+--------+--------------------+
| user_id|        is_following|
+--------+--------------------+
|       1|[2, 3, 4, 5, 6, 7]  |
|       2|[20, 30, 40, 50]    |
+--------+--------------------+

我可以确认这有架构:

root
 |-- user_id: integer (nullable = true)
 |-- is_following: array (nullable = true)
 |    |-- element: integer (containsNull = true)

我想使用Spark的ML例程(如LDA)对此进行一些机器学习,要求我将is_following列转换为linalg.Vector(不是Scala向量) 。当我尝试通过

这样做时
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.Vectors

val assembler = new VectorAssembler().setInputCols(Array("is_following")).setOutputCol("features")
val output = assembler.transform(df)

然后我收到以下错误:

java.lang.IllegalArgumentException: Data type ArrayType(IntegerType,true) is not supported.

如果我正确地解释了这一点,我会从中删除我需要将类型从整数转换为其他类型。 (Double?String?)

我的问题是将此数组转换为适合ML管道的矢量化的最佳方法是什么?

编辑:如果有帮助,我不必这样构建数据框架。我可以改为:

+--------+------------+
| user_id|is_following|
+--------+------------+
|       1|           2|
|       1|           3|
|       1|           4|
|       1|           5|
|       1|           6|
|       1|           7|
|       2|          20|
|     ...|         ...|
+--------+------------+

2 个答案:

答案 0 :(得分:1)

将数组转换为linalg.Vector并同时将整数转换为双精度的简单解决方案是使用UDF

使用您的数据框:

val spark = SparkSession.builder.getOrCreate()
import spark.implicits._

val df = spark.createDataFrame(Seq((1, Array(2,3,4,5,6,7)), (2, Array(20,30,40,50))))
  .toDF("user_id", "is_following")

val convertToVector = udf((array: Seq[Int]) => {
  Vectors.dense(array.map(_.toDouble).toArray)
})

val df2 = df.withColumn("is_following", convertToVector($"is_following"))
此处导入

spark.implicits._以允许使用$col()'代替。

打印df2数据框将显示想要的结果:

+-------+-------------------------+
|user_id|is_following             |
+-------+-------------------------+
|1      |[2.0,3.0,4.0,5.0,6.0,7.0]|
|2      |[20.0,30.0,40.0,50.0]    |
+-------+-------------------------+

模式:

root
 |-- user_id: integer (nullable = false)
 |-- is_following: vector (nullable = true)

答案 1 :(得分:0)

因此,您的初始输入可能比您的转换输入更适合。 Spark的VectorAssembler要求所有列都是双打,而不是双打的数组。由于不同的用户可以跟随不同数量的人,你当前的结构可能是好的,你只需要将is_following转换为Double,你可以使用Spark的VectorIndexer https://spark.apache.org/docs/2.1.0/ml-features.html#vectorindexer来实现这一点,或者只是在SQL中手动执行。

所以tl; dr是 - 类型错误是因为Spark的Vector只支持双打(这在不太遥远的未来可能会改变图像数据但不适合你的用例)和你'替代结构可能实际上更适合(没有分组的那个)。

您可能会发现Spark文档中的协作过滤示例对您的进一步冒险有用 - https://spark.apache.org/docs/latest/ml-collaborative-filtering.html。祝你好运并享受Spark ML的乐趣:)

编辑:

我注意到你说你想在输入上做LDA,所以让我们看一下如何为这种格式准备数据。对于LDA输入,您可能需要考虑使用CountVectorizer(请参阅https://spark.apache.org/docs/2.1.0/ml-features.html#countvectorizer