访问Spark DataFrame中矢量的元素

时间:2018-05-10 08:47:49

标签: python apache-spark vector pyspark spark-dataframe

我想访问下面显示的spark数据框中的向量元素: -

name_cv

(262144,[88783,143375,220659,228248],[1.0,1.0,1.0,1.0])

(262144,[220659],[1.0])

(262144,[75742],[1.0])

(262144,[68369,95745,107911,224494],[1.0,1.0,1.0,1.0])
& so on

我想访问第一行的88783,143375,220659,228248以及1.0,1.0,1.0,1.0等等...(对于数据帧中的其他行)。我需要这些元素,以便我可以计算元素的平均值。请帮帮我。

我已按照以下StackOverflow帖子中引用的步骤操作,但它们对我没有用。 How to access element of a VectorUDT column in a Spark DataFrame?

Access element of a vector in a Spark DataFrame (Logistic Regression probability vector)

我尝试添加此UDF: -

from pyspark.sql.types import DoubleType
from pyspark.sql.functions import lit, udf

def ith_(v, i):
    try:
        return float(v[i])
    except ValueError:
        return None

ith = udf(ith_, DoubleType())

appended.select(ith("name_cv", lit(1))).show()

输出结果为: -

+----------------+
|ith_(name_cv, 1)|
+----------------+
|             0.0|
|             0.0|
|             0.0|
|             0.0|
|             0.0|
|             0.0|
|             0.0|
|             0.0|
|             0.0|
|             0.0|
|             0.0|
|             0.0|
|             0.0|
|             0.0|
|             0.0|
|             0.0|
|             0.0|
|             0.0|
|             0.0|
|             0.0|
+----------------+

输出应为

+----------------+
    |ith_(name_cv, 1)|
    +----------------+
    |[88783,143375,220659,228248]|
    |[220659]|
    |[75742]|
    |[220659]|
    & so on
    +----------------+

谢谢!

1 个答案:

答案 0 :(得分:0)

  

我需要这些元素,以便我可以计算元素的平均值。

为什么不直接这样做:

from pyspark.sql.functions import udf

@udf("double")
def vector_mean(v):
    if v is None:
        return None 
    elif hasattr(v, "values"):
        return v.values.mean().tolist()
    else:
        return v.mean().tolist()

appended.select(vector_mean("name_cv"))

否则:

@udf("array<double>")
def values(v):
    if v is None:
        return None 
    elif hasattr(v, "values"):
        return v.values.tolist()
    else:
        return v.array.tolist()