如何通过索引而不是名称来获取列?

时间:2017-12-28 12:00:25

标签: python apache-spark pyspark apache-spark-sql

我有以下初始的PySpark DataFrame:

+----------+--------------------------------+
|product_PK|                        products|
+----------+--------------------------------+
|      686 |          [[686,520.70],[645,2]]|
|      685 |[[685,45.556],[678,23],[655,21]]|
|      693 |                              []|
df = sqlCtx.createDataFrame(
    [(686, [[686,520.70], [645,2]]), (685, [[685,45.556], [678,23],[655,21]]), (693, [])],
    ["product_PK", "products"]
)

products包含嵌套数据。我需要在每对值中提取第二个值。我正在运行此代码:

temp_dataframe = dataframe.withColumn("exploded" , explode(col("products"))).withColumn("score", col("exploded").getItem("_2"))

适用于特定的DataFrame。但是,我想将此代码放入函数中并在不同的DataFrame上运行它。我的所有DataFrame都具有相同的结构。唯一的区别是子列"_2"在某些DataFrame中的名称可能不同,例如"col1""col2"

例如:

DataFrame content
root
 |-- product_PK: long (nullable = true)
 |-- products: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- _1: long (nullable = true)
 |    |    |-- _2: double (nullable = true)
 |-- exploded: struct (nullable = true)
 |    |-- _1: long (nullable = true)
 |    |-- _2: double (nullable = true)
DataFrame content
root
 |-- product_PK: long (nullable = true)
 |-- products: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- product_PK: long (nullable = true)
 |    |    |-- col2: integer (nullable = true)
 |-- exploded: struct (nullable = true)
 |    |-- product_PK: long (nullable = true)
 |    |-- col2: integer (nullable = true)

我尝试使用getItem(1)这样的索引,但它说必须提供列的名称。

有没有办法避免指定列名或以某种方式概括代码的这一部分?

我的目标是exploded包含嵌套数据中每对的第二个值,即_2col1col2

3 个答案:

答案 0 :(得分:2)

听起来你走在正确的轨道上。我认为实现此目的的方法是阅读模式以确定要爆炸的字段的名称。但是,您需要使用schema.fields来查找struct字段,然后使用它的属性来计算结构中的字段,而不是schema.names。这是一个例子:

from pyspark.sql.functions import *
from pyspark.sql.types import *

# Setup the test dataframe
data = [
    (686, [(686, 520.70), (645, 2.)]), 
    (685, [(685, 45.556), (678, 23.), (655, 21.)]), 
    (693, [])
]

schema = StructType([
    StructField("product_PK", StringType()),
    StructField("products", 
        ArrayType(StructType([
            StructField("_1", IntegerType()),
            StructField("col2", FloatType())
        ]))
    )
])

df = sqlCtx.createDataFrame(data, schema) 

# Find the products field in the schema, then find the name of the 2nd field
productsField = next(f for f in df.schema.fields if f.name == 'products')
target_field = productsField.dataType.elementType.names[1]

# Do your explode using the field name
temp_dataframe = df.withColumn("exploded" , explode(col("products"))).withColumn("score", col("exploded").getItem(target_field))

现在,如果你检查结果,你会得到这个:

>>> temp_dataframe.printSchema()
root
 |-- product_PK: string (nullable = true)
 |-- products: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- _1: integer (nullable = true)
 |    |    |-- col2: float (nullable = true)
 |-- exploded: struct (nullable = true)
 |    |-- _1: integer (nullable = true)
 |    |-- col2: float (nullable = true)
 |-- score: float (nullable = true)

答案 1 :(得分:1)

这就是你想要的吗?

>>> df.show(10, False)
+----------+-----------------------------------------------------------------------+
|product_PK|products                                                               |
+----------+-----------------------------------------------------------------------+
|686       |[WrappedArray(686, null), WrappedArray(645, 2)]                        |
|685       |[WrappedArray(685, null), WrappedArray(678, 23), WrappedArray(655, 21)]|
|693       |[]                                                                     |
+----------+-----------------------------------------------------------------------+

>>> import pyspark.sql.functions as F
>>> df.withColumn("exploded", F.explode("products")) \
...   .withColumn("exploded", F.col("exploded").getItem(1)) \
...   .show(10,False)
+----------+-----------------------------------------------------------------------+--------+
|product_PK|products                                                               |exploded|
+----------+-----------------------------------------------------------------------+--------+
|686       |[WrappedArray(686, null), WrappedArray(645, 2)]                        |null    |
|686       |[WrappedArray(686, null), WrappedArray(645, 2)]                        |2       |
|685       |[WrappedArray(685, null), WrappedArray(678, 23), WrappedArray(655, 21)]|null    |
|685       |[WrappedArray(685, null), WrappedArray(678, 23), WrappedArray(655, 21)]|23      |
|685       |[WrappedArray(685, null), WrappedArray(678, 23), WrappedArray(655, 21)]|21      |
+----------+-----------------------------------------------------------------------+--------+

答案 2 :(得分:0)

鉴于您的exploded列是struct

 |-- exploded: struct (nullable = true)
 |    |-- _1: integer (nullable = true)
 |    |-- col2: float (nullable = true)

您可以使用以下逻辑来获取第二个元素而不知道名称

from pyspark.sql import functions as F
temp_dataframe = df.withColumn("exploded" , F.explode(F.col("products")))
temp_dataframe.withColumn("score", F.col("exploded."+temp_dataframe.select(F.col("exploded.*")).columns[1]))

你应该输出

+----------+--------------------------------------+------------+------+
|product_PK|products                              |exploded    |score |
+----------+--------------------------------------+------------+------+
|686       |[[686,520.7], [645,2.0]]              |[686,520.7] |520.7 |
|686       |[[686,520.7], [645,2.0]]              |[645,2.0]   |2.0   |
|685       |[[685,45.556], [678,23.0], [655,21.0]]|[685,45.556]|45.556|
|685       |[[685,45.556], [678,23.0], [655,21.0]]|[678,23.0]  |23.0  |
|685       |[[685,45.556], [678,23.0], [655,21.0]]|[655,21.0]  |21.0  |
+----------+--------------------------------------+------------+------+