我有一个数据框,其中包含以下数据:
df.show()
+-----+------+--------+
| id_A| idx_B| B_value|
+-----+------+--------+
| a| 0| 7|
| b| 0| 5|
| b| 2| 2|
+-----+------+--------+
假设B总共有3个可能的索引,我想创建一个表,它将所有索引和值合并到一个看起来像这样的列表(或numpy数组)中:
final_df.show()
+-----+----------+
| id_A| B_values|
+-----+----------+
| a| [7, 0, 0]|
| b| [5, 0, 2]|
+-----+----------+
我已经设法做到这一点:
from pyspark.sql import functions as f
temp_df = df.withColumn('B_tuple', f.struct(df['idx_B'], df['B_value']))\
.groupBy('id_A').agg(f.collect_list('B_tuple').alias('B_tuples'))
temp_df.show()
+-----+-----------------+
| id_A| B_tuples|
+-----+-----------------+
| a| [[0, 7]]|
| b| [[0, 5], [2, 2]]|
+-----+-----------------+
但是现在我无法运行适当的udf
函数将temp_df
转换为final_df
。
有没有更简单的方法?
如果没有,我应该使用什么适当的函数来完成转换?
答案 0 :(得分:1)
所以我找到了解决方法,
def create_vector(tuples_list, size):
my_list = [0] * size
for x in tuples_list:
my_list[x["idx_B"]] = x["B_value"]
return my_list
create_vector_udf = f.udf(create_vector, ArrayType(IntegerType()))
final_df = temp_df.with_column('B_values', create_vector_udf(temp_df['B_tuples'])).select(['id_A', 'B_values'])
final_df.show()
+-----+----------+
| id_A| B_values|
+-----+----------+
| a| [7, 0, 0]|
| b| [5, 0, 2]|
+-----+----------+
答案 1 :(得分:0)
如果您已经知道数组的size
,则可以在没有udf
的情况下进行操作。
利用pivot()
的可选第二个参数:values
。这需要一个
将转换为输出DataFrame中的列的值列表
这样groupBy
id_A
列,然后在idx_B
列上旋转DataFrame。由于并非所有索引都可能存在,因此您可以将range(size)
作为values
参数传递。
import pyspark.sql.functions as f
size = 3
df = df.groupBy("id_A").pivot("idx_B", values=range(size)).agg(f.first("B_value"))
df = df.na.fill(0)
df.show()
#+----+---+---+---+
#|id_A| 0| 1| 2|
#+----+---+---+---+
#| b| 5| 0| 2|
#| a| 7| 0| 0|
#+----+---+---+---+
数据中不存在的索引将默认为null
,因此我们将na.fill(0)
称为默认值。
一旦您使用这种格式的数据,您只需要从列中创建一个数组:
df.select("id_A", f.array([f.col(str(i)) for i in range(size)]).alias("B_values")).show()
#+----+---------+
#|id_A| B_values|
#+----+---------+
#| b|[5, 0, 2]|
#| a|[7, 0, 0]|
#+----+---------+