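After fitting a RandomForestClassifier for binary classification and predicting on a dataset, I get a transformed DataFrame df with label, prediction and probability columns.
Goal:
I want to create a new column "prob_flag" that holds the predicted probability of label "1". It is the second element of the array of probabilities (which is itself the fourth element, index 3, of the outer array).
I looked at similar topics, but I am getting an error that does not come up in those threads.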
df.show()
label  prediction  probability
0      0           [1,2,[],[0.7558548984793847,0.2441451015206153]]
0      0           [1,2,[],[0.5190322149055472,0.4809677850944528]]
0      1           [1,2,[],[0.4884140358521083,0.5115859641478916]]
0      1           [1,2,[],[0.4884140358521083,0.5115859641478916]]
1      1           [1,2,[],[0.40305518381637956,0.5969448161836204]]
1      1           [1,2,[],[0.40570407426458577,0.5942959257354141]]
# The probability column is VectorUDT and looks like an array of dim 4 that contains probabilities of predicted variables I want to retrieve
df.schema
StructType(List(StructField(label,DoubleType,true),StructField(prediction,DoubleType,false),StructField(probability,VectorUDT,true)))
# I tried this:
import pyspark.sql.functions as f
df.withColumn("prob_flag", f.array([f.col("probability")[3][1]])).show()
"Can't extract value from probability#6225: need struct type but got struct<type:tinyint,size:int,indices:array<int>,values:array<double>>;"
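Expected result: a new column "prob_flag" holding the probability of label "1", i.e. the second number in the array of probabilities. For the rows shown, that is roughly 0.24, 0.48, 0.51, 0.51, 0.59, 0.59.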
Answer 0 (score: 1)
Unfortunately, you cannot extract fields from a VectorUDT column as if it were an ArrayType.
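To illustrate what the error message is describing: the vector column is stored as a record with the fields type, size, indices and values, and Spark does not expose those fields through ordinary column indexing. The pure-Python sketch below mimics how one element would be read from such a record. It assumes the Spark ML convention that type 1 marks a dense vector and type 0 a sparse one; the helper name element_at is hypothetical and is not part of PySpark.

```python
def element_at(record, i):
    """Read element i from a vector stored as (type, size, indices, values)."""
    kind, size, indices, values = record
    if kind == 1:
        # Dense vector (assumed type code 1): values holds every element in order.
        return float(values[i])
    # Sparse vector (assumed type code 0): only the positions in indices are stored.
    if not 0 <= i < size:
        raise IndexError(i)
    for pos, idx in enumerate(indices):
        if idx == i:
            return float(values[pos])
    return 0.0  # positions that are not stored are implicitly zero


# First row of the probability column from the question.
row = (1, 2, [], [0.7558548984793847, 0.2441451015206153])
print(element_at(row, 1))
```

This is also why a Python function is a convenient workaround: inside a UDF the value arrives as a vector object that already supports v[1], so none of this decoding has to be written by hand.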
You have to use a udf instead:
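from pyspark.sql.types import DoubleType
from pyspark.sql.functions import udf, col

def extract_prob(v):
    # v is the row's probability vector; index 1 is the probability of label 1
    try:
        return float(v[1])  # Your VectorUDT is of length 2
    except (IndexError, TypeError, ValueError):
        # Null or too-short vector: return null instead of failing the job
        return None

extract_prob_udf = udf(extract_prob, DoubleType())
df2 = df.withColumn("prob_flag", extract_prob_udf(col("probability")))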
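If you are on Spark 3.0 or later, a UDF-free variant may be worth trying: pyspark.ml.functions.vector_to_array converts the vector column into an ordinary array column, which can then be indexed like any ArrayType. The following is a minimal sketch under that version assumption; the helper name add_prob_flag and its parameters are hypothetical, and the imports are placed inside the function only so the snippet can be loaded on a machine without Spark.

```python
def add_prob_flag(df, vector_col="probability", out_col="prob_flag", index=1):
    """Return df with out_col set to element `index` of the vector column."""
    # Assumes Spark >= 3.0, where vector_to_array is available.
    from pyspark.ml.functions import vector_to_array
    from pyspark.sql.functions import col

    # vector_to_array yields array<double>, so plain indexing is allowed here.
    return df.withColumn(out_col, vector_to_array(col(vector_col))[index])
```

Usage would be df2 = add_prob_flag(df). Because the conversion runs inside the JVM, it avoids shipping every row through a Python worker, which is usually the main cost of the udf approach.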