在pyspark SQL DataFrame中乘以稀疏向量的行

时间:2017-05-19 07:58:53

标签: python apache-spark pyspark pyspark-sql

我在SQL数据框中乘以列的元素时遇到了困难。

sv1 = Vectors.sparse(3, [0, 2], [1.0, 3.0])
sv2 = Vectors.sparse(3, [0, 1], [2.0, 4.0])

def xByY(x,y):
  return np.multiply(x,y)

print(xByY(sv1, sv2))

以上作品。

但是下面没有。

xByY_udf = udf(xByY)

tempDF = sqlContext.createDataFrame([(sv1, sv2), (sv1, sv2)], ('v1', 'v2'))
tempDF.show()

print(tempDF.select(xByY_udf('v1', 'v2')).show())

非常感谢!

1 个答案:

答案 0 :(得分:2)

如果您希望udf返回SparseVector,我们首先需要修改您的函数输出,然后将udf的输出架构设置为VectorUDT()

要声明SparseVector,我们需要原始数组的大小,以及索引 非零元素。如果乘法的中间结果是len(),我们可以使用list和列表推导找到这些:

from pyspark.ml.linalg import Vectors, VectorUDT

def xByY(x,y):
  res = np.multiply(x,y).tolist()
  vec_args =  len(res), [i for i,x in enumerate(res) if x != 0], [x for x in res if x != 0] 
  return Vectors.sparse(*vec_args)  

现在我们可以声明我们的udf并测试它:

xByY_udf = udf(xByY, VectorUDT())
tempDF.select(xByY_udf('v1', 'v2')).show()
+-------------+
| xByY(v1, v2)|
+-------------+
|(3,[0],[2.0])|
|(3,[0],[2.0])|
+-------------+