Apache scala spark中数据框列中稀疏向量的大小

时间:2016-04-22 20:58:42

标签: scala apache-spark

我正在使用矢量汇编程序来转换数据帧。

var stringAssembler = new VectorAssembler().setInputCols(encodedstringColumns).setOutputCol("stringFeatures")
df = stringAssembler.transform(df)
**var stringVectorSize = df.select("stringFeatures").head.size**
var stringPca = new PCA().setInputCol("stringFeatures").setOutputCol("pcaStringFeatures").setK(stringVectorSize).fit(output)

现在,stringVectorSize将告诉PCA在执行pca时要保留多少列。 我试图从向量汇编程序获取输出稀疏向量的大小,但我的代码给出size = 1这是错误的。获取稀疏向量大小的正确代码是什么,它是数据帧列的一部分。

说白了

+-------------+------------+-------------+------------+---+-----------+---------------+-----------------+--------------------+
|PetalLengthCm|PetalWidthCm|SepalLengthCm|SepalWidthCm| Id|    Species|Species_Encoded|       Id_Encoded|      stringFeatures|
+-------------+------------+-------------+------------+---+-----------+---------------+-----------------+--------------------+
|          1.4|         0.2|          5.1|         3.5|  1|Iris-setosa|  (2,[0],[1.0])| (149,[91],[1.0])|(151,[91,149],[1....|
|          1.4|         0.2|          4.9|         3.0|  2|Iris-setosa|  (2,[0],[1.0])|(149,[119],[1.0])|(151,[119,149],[1...|
|          1.3|         0.2|          4.7|         3.2|  3|Iris-setosa|  (2,[0],[1.0])|(149,[140],[1.0])|(151,[140,149],[1...|

对于上述数据帧。我想提取stringFeatures稀疏向量(即151)的大小

1 个答案:

答案 0 :(得分:1)

如果您阅读DataFrame's documentation,您会注意到head方法返回Row。因此,您获取SparseVector的尺寸,而不是获得Row的尺寸。因此,要解决此问题,您必须提取存储在Row中的元素。

val row = df.select("stringFeatures").head 
val vector = vector(0).asInstanceOf[SparseVector]
val size = vector.size

例如:

import sqlContext.implicits._
import org.apache.spark.mllib.linalg.SparseVector

val df = sc.parallelize(Array(10,2,3,4)).toDF("n")
val pepe = udf((i: Int) => new SparseVector(i, Array(i-1), Array(i)))
val x = df.select(pepe(df("n")).as("n"))

x.show()

+---------------+
|              n|
+---------------+
|(10,[9],[10.0])|
|  (2,[1],[2.0])|
|  (3,[2],[3.0])|
|  (4,[3],[4.0])|
+---------------+

val y = x.select("n").head

y(0).asInstanceOf[SparseVector].size
res12: Int = 10