我有一个PySpark DataFrame
Col1 Col2 Col3
0.1 0.2 0.3
我想获取列名称,其中至少有一行符合条件,例如行大于0.1
我的预期结果应该是这种情况:
[Co2 , Co3]
我无法提供任何代码,因为我真的不知道该怎么做。
答案 0 :(得分:3)
只需count
项满足谓词(内部select
)并处理结果:
from pyspark.sql.functions import col, count, when
[c for c, v in df.select([
count(when(col(c) > 0.1, 1)).alias(c) for c in df.columns
]).first().asDict().items() if v]
一步一步:
汇总(DataFrame
- > DatFrame
):
df = sc.parallelize([(0.1, 0.2, 0.3)]).toDF()
counts = df.select([
count(when(col(c) > 0.1, 1)).alias(c) for c in df.columns
])
DataFrame[_1: bigint, _2: bigint, _3: bigint]
collect
first
Row
:
a_row = counts.first()
Row(_1=0, _2=1, _3=1)
转换为Python dict
:
a_dict = a_row.asDict()
{'_1': 0, '_2': 1, '_3': 1}
迭代它的项目,保持键,当值是真的时候:
[c for c, v in a_dict.items() if v]
或明确检查计数:
[c for c, v in a_dict.items() if v > 0]