Following up on my earlier question, "Convert a Spark Vector of features into an array", I have made some progress:
def extractUdf = udf((v: SDV) => v.toArray)
val temp: DataFrame = dataWithFeatures.withColumn("extracted_features", extractUdf($"features"))
temp.printSchema()
val featuresArray1: Array[Double] = temp.rdd.map(r => r.getAs[Double](0)).collect
val featuresArray2: Array[Double] = temp.rdd.map(r => r.getAs[Double](1)).collect
val featuresArray3: Array[Double] = temp.rdd.map(r => r.getAs[Double](2)).collect
val allfeatures: Array[Array[Double]] = Array(featuresArray1, featuresArray2, featuresArray3)
val flatfeatures: Array[Double] = allfeatures.flatten
This seems to give the result I want. The extractUdf function converts the features Vector into extracted_features:
|-- features: vector (nullable = true)
|-- extracted_features: array (nullable = true)
| |-- element: double (containsNull = false)
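However, I don't understand why the next three lines of code (the arrays featuresArray1, featuresArray2, and featuresArray3) pick up extracted_features rather than any other column in temp (such as {{{ 1}}), or how to make the array indices (0, 1, 2) follow the number of features instead of being hard-coded. Thanks for your help!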
Answer 0 (score: 3)
Suppose you have a dataframe like this:
+---+-------------+
|id |features |
+---+-------------+
|1 |[1.0,2.0,3.0]|
|2 |[3.0,4.0,8.0]|
+---+-------------+
with the schema
root
|-- id: integer (nullable = false)
|-- features: vector (nullable = true)
You have already extracted the features vector into an Array with
import org.apache.spark.sql.functions._
import org.apache.spark.mllib.linalg.DenseVector
def extractUdf = udf((v: DenseVector) => v.toArray)
val temp = dataWithFeatures.withColumn("extracted_features", extractUdf($"features"))
which gives
+---+-------------+------------------+
|id |features |extracted_features|
+---+-------------+------------------+
|1 |[1.0,2.0,3.0]|[1.0, 2.0, 3.0] |
|2 |[3.0,4.0,8.0]|[3.0, 4.0, 8.0] |
+---+-------------+------------------+
root
|-- id: integer (nullable = false)
|-- features: vector (nullable = true)
|-- extracted_features: array (nullable = true)
| |-- element: double (containsNull = false)
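On the question of why the indices behave as they do: Row.getAs[T](i) takes a column position within the row, so 0, 1, 2 in the question select columns of temp, not elements of the array. If the goal is one driver-side array per feature position, a minimal sketch is to collect the array column by name and transpose it. The object and method names below are hypothetical, the sample data simply mirrors the table above, and the Spark call in the comment is one possible way to obtain the rows, not necessarily the asker's setup.

```scala
object FeatureArrays {
  // Turns "one array per row" into "one array per feature position".
  // Assumes every row has the same number of features.
  def perFeature(rows: Array[Array[Double]]): Array[Array[Double]] =
    rows.transpose

  // Same layout as allfeatures.flatten in the question:
  // feature 0 for all rows, then feature 1, and so on.
  def flat(rows: Array[Array[Double]]): Array[Double] =
    perFeature(rows).flatten

  def main(args: Array[String]): Unit = {
    // Stand-in for the collected column. With Spark this might come from:
    //   temp.select("extracted_features").rdd
    //     .map(_.getSeq[Double](0).toArray).collect()
    val rows = Array(
      Array(1.0, 2.0, 3.0), // id 1
      Array(3.0, 4.0, 8.0)  // id 2
    )

    println(perFeature(rows).map(_.mkString("[", ", ", "]")).mkString(" "))
    println(flat(rows).mkString(", "))
  }
}
```

Because the number of per-feature arrays comes from the data itself, nothing here depends on knowing the feature count in advance. Collecting to the driver is only reasonable when the data fits in memory.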
Now, referencing elements of the extracted_features Array column works the same way as with any other array type in Scala. So you can do
temp.withColumn("firstValue", $"extracted_features"(0))
.withColumn("secondValue", $"extracted_features"(1))
.withColumn("thirdValue", $"extracted_features"(2))
which will give you
+---+-------------+------------------+----------+-----------+----------+
|id |features |extracted_features|firstValue|secondValue|thirdValue|
+---+-------------+------------------+----------+-----------+----------+
|1 |[1.0,2.0,3.0]|[1.0, 2.0, 3.0] |1.0 |2.0 |3.0 |
|2 |[3.0,4.0,8.0]|[3.0, 4.0, 8.0] |3.0 |4.0 |8.0 |
+---+-------------+------------------+----------+-----------+----------+
root
|-- id: integer (nullable = false)
|-- features: vector (nullable = true)
|-- extracted_features: array (nullable = true)
| |-- element: double (containsNull = false)
|-- firstValue: double (nullable = true)
|-- secondValue: double (nullable = true)
|-- thirdValue: double (nullable = true)
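To avoid hard-coding the indices 0, 1, 2, one possible approach is to read the array length from a single row and fold over that range, adding one column per element. This is a sketch under assumptions: the helper expand, the feature_ column prefix, and the local SparkSession are all hypothetical, every row is assumed to have the same vector length, and the mllib Vector import follows the snippet above (if features was produced by a spark.ml transformer, org.apache.spark.ml.linalg would probably be needed instead).

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.mllib.linalg.{Vector, Vectors}

object ExpandFeatures {
  // Adds one column per array element: feature_0, feature_1, ...
  def expand(df: DataFrame, arrayCol: String, n: Int,
             prefix: String = "feature_"): DataFrame =
    (0 until n).foldLeft(df) { (acc, i) =>
      acc.withColumn(s"$prefix$i", col(arrayCol)(i))
    }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]").appName("expand-features").getOrCreate()
    import spark.implicits._

    // Sample data mirroring the table above.
    val dataWithFeatures = Seq(
      (1, Vectors.dense(1.0, 2.0, 3.0)),
      (2, Vectors.dense(3.0, 4.0, 8.0))
    ).toDF("id", "features")

    // Typed on Vector rather than DenseVector so sparse vectors also pass.
    val extractUdf = udf((v: Vector) => v.toArray)
    val temp = dataWithFeatures
      .withColumn("extracted_features", extractUdf(col("features")))

    // Assumption: the first row is representative of the vector length.
    val numFeatures =
      temp.select("extracted_features").head().getSeq[Double](0).length

    expand(temp, "extracted_features", numFeatures).show(false)
    spark.stop()
  }
}
```

Reading the length from the first row triggers a small Spark job. If the feature count is already known from the pipeline that built the vector, passing it in directly avoids that.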
I hope this answer is helpful.