我进行了tf-idf转换,现在我想从结果中获取键和值。
我正在使用以下udf代码获取值:
def extract_values_from_vector(vector):
return vector.values.tolist()
extract_values_from_vector_udf = udf(lambda vector:extract_values_from_vector(vector), ArrayType(DoubleType()))
extract = rescaledData.withColumn("extracted_keys", extract_keys_from_vector_udf("features"))
因此,如果稀疏向量看起来像: features = SparseVector(123241,{20672:4.4233,37393:0.0,109847:3.7096,118474:5.4042}))
我的摘录中的extracted_keys如下所示: [4.4233,0.0,3.7096,5.4042]
我的问题是,如何获得SparseVector词典中的密钥?例如键= [20672、37393、109847、118474]?
我正在尝试以下代码,但无法正常工作
def extract_keys_from_vector(vector):
return vector.indices.tolist()
extract_keys_from_vector_udf = spf.udf(lambda vector:extract_keys_from_vector(vector), ArrayType(DoubleType()))
它给我的结果是:[null,null,null,null]
有人可以帮忙吗? 提前非常感谢!
答案 0 :(得分:0)
因为答案在上面的评论中,所以我认为我会花这段时间(当然要等到拼写的时候)写下答案。
from pyspark.sql.types import *
from pyspark.sql import functions as F
def extract_keys_from_vector(vector):
return vector.indices.tolist()
feature_extract = F.UserDefinedFunction(lambda vector: extract_keys_from_vector(vector), ArrayType(IntegerType()))
df = df.withColumn("features", feature_extract(F.col("features")))