Spark - how to join current and previous records in a DataFrame and assign a field from the original record to all linked records

Time: 2018-12-31 09:45:57

Tags: apache-spark pyspark pyspark-sql

I need to traverse a Hive table and add the value from the first record in each sequence to all of its linked records.

The logic would be:-

  1. Find the first record (where previous_id is blank).
  2. Find the next record (current_id = previous_id).
  3. Repeat until there are no more linked records.
  4. Add the columns from the original record to all linked records.
  5. Write the results out to a Hive table.

Sample source data:-

current_id previous_id start_date
---------- ----------- ----------
100                    01/01/2001
200        100         02/02/2002
300        200         03/03/2003

Sample output data:-

current_id start_date
---------- ----------
100        01/01/2001
200        01/01/2001
300        01/01/2001

I can achieve this by creating two DataFrames from the source table and performing multiple joins. However, that approach does not seem ideal, since the data has to be cached to avoid re-querying the source data on every iteration.

Any suggestions on how to approach this?

3 Answers:

Answer 0 (score: 1)

I think you can accomplish this with GraphFrames connected components.

It will help you avoid writing the checkpointing and looping logic yourself. Essentially, you build a graph from the current_id / previous_id pairs and let GraphFrames compute the connected component for each vertex. The resulting DataFrame can then be joined back to the original DataFrame to pick up the start_date.

from graphframes import *

sc.setCheckpointDir("/tmp/chk")

input = spark.createDataFrame([
  (100, None, "2001-01-01"),
  (200, 100, "2002-02-02"),
  (300, 200, "2003-03-03"),
  (400, None, "2004-04-04"),
  (500, 400, "2005-05-05"),
  (600, 500, "2006-06-06"),
  (700, 300, "2007-07-07")
], ["current_id", "previous_id", "start_date"])

input.show()

vertices = input.select(input.current_id.alias("id"))

edges = input.select(input.current_id.alias("src"), input.previous_id.alias("dst"))

graph = GraphFrame(vertices, edges)

result = graph.connectedComponents()

result.join(input.where(input.previous_id.isNull()), result.component == input.current_id)\
  .select(result.id.alias("current_id"), input.start_date)\
  .orderBy("current_id")\
  .show()

The result looks like this:

+----------+----------+
|current_id|start_date|
+----------+----------+
|       100|2001-01-01|
|       200|2001-01-01|
|       300|2001-01-01|
|       400|2004-04-04|
|       500|2004-04-04|
|       600|2004-04-04|
|       700|2001-01-01|
+----------+----------+
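
Step 5 of the question asks for the output to be written to a Hive table rather than just shown. A minimal sketch of that final step, assuming the SparkSession was created with Hive support enabled and using linked_records as a placeholder table name:

# Sketch only: persist the joined result to a Hive table instead of calling show().
resolved = result.join(input.where(input.previous_id.isNull()),
                       result.component == input.current_id)\
  .select(result.id.alias("current_id"), input.start_date)

# saveAsTable writes a managed table; "linked_records" is a placeholder name.
resolved.write.mode("overwrite").saveAsTable("linked_records")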

Answer 1 (score: 0)

Here is an approach that I am not sure sits entirely well with Spark.

There is no grouping id / key for the data.

Not sure how Catalyst will optimize this - something to look at at a later point in time. Will it run into memory errors if the chains get too long?

I made the data a little more complicated, and it does work. Here goes:

# No grouping key evident, more a linked list with asc current_ids.
# Added more complexity to the example.
# Questions open on performance at scale. Interested to see how well Catalyst handles this.
# Need really some grouping id/key in the data.

from pyspark.sql import functions as f
from functools import reduce
from pyspark.sql import DataFrame
from pyspark.sql.functions import col

# Started from dataframe.
# Some more realistic data? At least more complex.
columns = ['current_id', 'previous_id', 'start_date']
vals = [
        (100, None, '2001/01/01'),
        (200, 100,  '2002/02/02'),
        (300, 200,  '2003/03/03'),
        (400, None, '2005/01/01'),
        (500, 400,  '2006/02/02'),
        (600, 300,  '2007/02/02'),
        (700, 600,  '2008/02/02'),
        (800, None, '2009/02/02'),
        (900, 800,  '2010/02/02')  
       ]
df = spark.createDataFrame(vals, columns)
df.createOrReplaceTempView("trans")

# Starting data. The null / None entries. 
df2 = spark.sql("""
                   select * 
                     from trans 
                    where previous_id is null
                """)
df2.cache()
df2.createOrReplaceTempView("trans_0")

# Loop through the stuff based on traversing the list elements until exhaustion of data, and, write to dynamically named TempViews.
# May need to checkpoint? Depends on depth of chain of linked items.
# Spark not well suited to this type of processing.  
dfX_cnt  = 1
cnt = 1

while (dfX_cnt != 0): 
  tabname_prev = 'trans_' + str(cnt-1)
  tabname = 'trans_' + str(cnt) 

  query = "select t2.current_id, t2.previous_id, t1.start_date from {} t1, trans t2 where t1.current_id = t2.previous_id".format(tabname_prev)
  dfX = spark.sql(query)
  dfX.cache()

  dfX_cnt = dfX.count()
  if (dfX_cnt!=0):
      #print('Looping for dynamic creation of TempViews')
      dfX.createOrReplaceTempView(tabname)
      cnt=cnt+1

# Reduce the TempViews all to one DF. Can reduce an array of DF's as well, but could not find my notes here in this regard.
# Will memory errors occur? 

from pyspark.sql.types import *
fields = [StructField('current_id', LongType(), False),
          StructField('previous_id', LongType(), True),
          StructField('start_date',  StringType(), False)]
schema = StructType(fields)
dfZ = spark.createDataFrame(sc.emptyRDD(), schema)

for i in range(0,cnt,1):
    tabname = 'trans_' + str(i)
    query = "select * from {}".format(tabname)
    df = spark.sql(query)
    dfZ = dfZ.union(df)

# Show final results.
dfZ.select('current_id', 'start_date').sort(col('current_id')).show()

Returns:

+----------+----------+
|current_id|start_date|
+----------+----------+
|       100|2001/01/01|
|       200|2001/01/01|
|       300|2001/01/01|
|       400|2005/01/01|
|       500|2005/01/01|
|       600|2001/01/01|
|       700|2001/01/01|
|       800|2009/02/02|
|       900|2009/02/02|
+----------+----------+
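
The reduce and DataFrame imports at the top of this answer are never used; they hint at the alternative mentioned in the comment about reducing an array of DataFrames. A minimal sketch of that pattern, assuming the per-level DataFrames are collected into a Python list while looping instead of being registered as dynamically named temp views (df and df2 as defined above):

# Sketch only: collect each level of the chain into a list, then fold the list
# into a single DataFrame with functools.reduce instead of unioning temp views.
from functools import reduce
from pyspark.sql import DataFrame, functions as f

levels = [df2]          # df2 holds the root rows (previous_id is null)
prev = df2
while True:
    nxt_level = prev.alias("t1").join(
        df.alias("t2"),
        f.col("t1.current_id") == f.col("t2.previous_id")
    ).select(
        f.col("t2.current_id"),
        f.col("t2.previous_id"),
        f.col("t1.start_date")
    ).cache()
    if nxt_level.count() == 0:
        break
    levels.append(nxt_level)
    prev = nxt_level

# unionByName aligns columns by name, so the fold does not depend on column order.
combined = reduce(DataFrame.unionByName, levels)
combined.select("current_id", "start_date").sort("current_id").show()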

Answer 2 (score: 0)

Thanks for the suggestions posted here. After trying various options, I went with the following solution, which works for multiple iterations (e.g. 20 loops) without causing any memory issues.

The physical plan is still large, but caching means most of the steps are skipped, which keeps performance acceptable.

input = spark.createDataFrame([
    (100, None, '2001/01/01'),
    (200, 100,  '2002/02/02'),
    (300, 200,  '2003/03/03'),
    (400, None, '2005/01/01'),
    (500, 400,  '2006/02/02'),
    (600, 300,  '2007/02/02'),
    (700, 600,  '2008/02/02'),
    (800, None, '2009/02/02'),
    (900, 800,  '2010/02/02')
], ["current_id", "previous_id", "start_date"])

input.createOrReplaceTempView("input")

cur = spark.sql("select * from input where previous_id is null")
nxt = spark.sql("select * from input where previous_id is not null")

cur.cache()
nxt.cache()

cur.createOrReplaceTempView("cur0")
nxt.createOrReplaceTempView("nxt")

i = 1
while True:
    spark.sql("set table_name=cur" + str(i - 1))
    cur = spark.sql(
        """
            SELECT  nxt.current_id  as current_id,  
                    nxt.previous_id as previous_id, 
                    cur.start_date  as start_date       
            FROM    ${table_name}   cur, 
                    nxt             nxt 
            WHERE   cur.current_id = nxt.previous_id 
        """).cache()
    cur.createOrReplaceTempView("cur" + str(i))
    i = i + 1
    if cur.count() == 0:
        break

for x in range(0, i):
    spark.sql("set table_name=cur" + str(x))
    cur = spark.sql("select * from ${table_name}")
    if x == 0:
        out = cur
    else:
        out = out.union(cur)
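
The loop above ends with the combined DataFrame in out but never shows or persists it. A short usage sketch for those remaining steps, plus an optional way to address the growing physical plan by truncating the lineage with checkpoint() (the checkpoint directory is a placeholder path):

# Sketch only: display the combined result.
out.select("current_id", "start_date").orderBy("current_id").show()

# If the plan still grows too large, checkpointing materializes the data and
# cuts the lineage; a checkpoint directory must be set first (placeholder path).
spark.sparkContext.setCheckpointDir("/tmp/chk")
out = out.checkpoint()   # eager by default; returns a DataFrame with truncated lineage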