如何从pyspark SparseVector获取密钥

时间:2019-01-01 16:34:35

标签: pyspark tf-idf

我进行了tf-idf转换,现在我想从结果中获取键和值。

我正在使用以下udf代码获取值:

def extract_values_from_vector(vector):
    return vector.values.tolist()

extract_values_from_vector_udf = udf(lambda vector:extract_values_from_vector(vector), ArrayType(DoubleType()))

extract = rescaledData.withColumn("extracted_keys", extract_keys_from_vector_udf("features"))

因此,如果稀疏向量看起来像: features = SparseVector(123241,{20672:4.4233,37393:0.0,109847:3.7096,118474:5.4042}))

我的摘录中的

extracted_keys如下所示: [4.4233,0.0,3.7096,5.4042]

我的问题是,如何获得SparseVector词典中的密钥?例如键= [20672、37393、109847、118474]?

我正在尝试以下代码,但无法正常工作

def extract_keys_from_vector(vector):
    return vector.indices.tolist()
extract_keys_from_vector_udf = spf.udf(lambda vector:extract_keys_from_vector(vector), ArrayType(DoubleType()))

它给我的结果是:[null,null,null,null]

有人可以帮忙吗? 提前非常感谢!

1 个答案:

答案 0 :(得分:0)

因为答案在上面的评论中,所以我认为我会花这段时间(当然要等到拼写的时候)写下答案。

from pyspark.sql.types import *
from pyspark.sql import functions as F

def extract_keys_from_vector(vector):
    return vector.indices.tolist()

feature_extract = F.UserDefinedFunction(lambda vector: extract_keys_from_vector(vector), ArrayType(IntegerType()))

df = df.withColumn("features", feature_extract(F.col("features")))