I have a SQL table with 40 columns: ID, Product, Product_ID, Date, etc., and I want to loop over all of the columns to get the distinct values of each.
Customer table (sample):
ID Product
1 gadget
2 VR
2 AR
3 hi-fi
I tried using dropDuplicates inside a function that loops over all the columns, but the output spits out only one distinct value per column instead of all possible distinct values. (A working version of such a loop is sketched after the actual result below.)
Expected result:
Column Value
ID 1
ID 2
ID 3
Product gadget
Product VR
Product AR
Product hi-fi
Actual result:
Column Value
ID 1
Product gadget
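For reference, a per-column loop along these lines does return every distinct value (a minimal sketch, not the asker's original code; it assumes the table is loaded as a PySpark DataFrame df, and casts each column to string so the unioned parts share one type):

from functools import reduce
from pyspark.sql.functions import lit, col

# Build one small (Column, Value) DataFrame per column, then union them all.
parts = [
    df.select(lit(c).alias('Column'), col(c).cast('string').alias('Value')).distinct()
    for c in df.columns
]
result = reduce(lambda a, b: a.union(b), parts)
result.show()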
Answer (score: 2)
The idea is to use collect_set() to gather the distinct elements of each column, and then to explode those sets back into a DataFrame.
from pyspark.sql.functions import collect_set

# All columns that need to be aggregated should be added to col_list.
col_list = ['ID', 'Product']
exprs = [collect_set(x) for x in col_list]
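For the real 40-column table, col_list does not have to be typed out by hand; it can be derived from the schema once df is loaded (a sketch, assuming every column should be included):

col_list = [c for c in df.columns]  # drop any columns you want to skip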
Let's start the aggregation.
from pyspark.sql.functions import lit, collect_set, explode, array, struct, col, expr
df = spark.createDataFrame([(1, 'gadget'), (2, 'VR'), (2, 'AR'), (3, 'hi-fi')], schema=['ID', 'Product'])
# A dummy grouping column lets us collapse the whole DataFrame into a single row.
df = df.withColumn('Dummy', lit('Dummy'))
# While exploding later, the datatypes must be the same, so we have to cast ID as a string.
df = df.withColumn('ID', col('ID').cast('string'))
# Collect the distinct values of each column into arrays.
df = df.groupby('Dummy').agg(*exprs)
df.show(truncate=False)
+-----+---------------+-----------------------+
|Dummy|collect_set(ID)|collect_set(Product) |
+-----+---------------+-----------------------+
|Dummy|[3, 1, 2] |[AR, VR, hi-fi, gadget]|
+-----+---------------+-----------------------+
The single wide row can now be transposed into (column_name, values) pairs:

def to_transpose(df, by):
    # Filter dtypes and split into column names and type description
    cols, dtypes = zip(*((c, t) for (c, t) in df.dtypes if c not in by))
    # Spark SQL supports only homogeneous columns
    assert len(set(dtypes)) == 1, "All columns have to be of the same type"
    # Create and explode an array of (column_name, column_value) structs
    kvs = explode(array([
        struct(lit(c).alias("key"), col(c).alias("val")) for c in cols
    ])).alias("kvs")
    return df.select(by + [kvs]).select(by + ["kvs.key", "kvs.val"])
df = to_transpose(df, ['Dummy']).drop('Dummy')
df.show()
+--------------------+--------------------+
| key| val|
+--------------------+--------------------+
| collect_set(ID)| [3, 1, 2]|
|collect_set(Product)|[AR, VR, hi-fi, g...|
+--------------------+--------------------+
df = df.withColumn('val', explode(col('val')))
df = df.withColumnRenamed('key', 'Column').withColumnRenamed('val', 'Value')
# Strip the "collect_set(" prefix (12 characters) and the trailing ")" from the column names.
df = df.withColumn('Column', expr("substring(Column,13,length(Column)-13)"))
df.show()
+-------+------+
| Column| Value|
+-------+------+
| ID| 3|
| ID| 1|
| ID| 2|
|Product| AR|
|Product| VR|
|Product| hi-fi|
|Product|gadget|
+-------+------+
Note: every column that is not already a string should be cast to string, e.g. df = df.withColumn('ID', col('ID').cast('string')). Otherwise you will get errors, because the array of structs built during the transpose requires a single element type.
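Instead of casting each column one by one, the cast can be folded into a loop over col_list (a small sketch, assuming the df and col_list defined above):

from pyspark.sql.functions import col

# Cast every column that will be aggregated to string up front,
# so the collected arrays share one element type.
for c in col_list:
    df = df.withColumn(c, col(c).cast('string'))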