Using functions and unit testing in PySpark

Time: 2020-05-15 07:59:37

Tags: apache-spark pyspark

Hope you're well.

I have a question that I'm struggling with.

I've seen that we can write PySpark unit tests with pytest and other frameworks, but they all test the output of a function.

In the script below, I'm not using any functions at all. All I want to say is:

If I put this data in, I expect that data to come out. Does that happen?

How do I do that? Should I be using functions? What should I do?

Thanks!

#!/usr/bin/env python
from datetime import datetime

from pyspark.sql import SparkSession
import pyspark.sql.functions as sqlfunc
from pyspark.sql.functions import when
from pyspark.sql.window import Window


now = datetime.now()
yday = int(now.strftime('%s')) - 1*24*60*60
yday_date = datetime.fromtimestamp(yday).strftime('%Y%m%d')
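# e.g. if the script runs on 2020-05-15, yday_date works out to '20200514'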



radiusday = int(now.strftime('%s')) - 3*24*60*60
radius_date = datetime.fromtimestamp(radiusday).strftime('%Y%m%d')

# create a Spark session with Hive support
def create_session(appname):
    spark_session = SparkSession\
        .builder\
        .appName(appname)\
        .master('yarn')\
        .enableHiveSupport()\
        .getOrCreate()
    return spark_session
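# (I assume a unit test would need a different builder here, e.g. .master('local[*]')
#  without enableHiveSupport(), but that is really part of my question)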

### START MAIN ###
if __name__ == '__main__':
    '''
    TABLE DEFINITION AND CACHING
    '''
    print(datetime.now())
    spark_session = create_session('ios12')
    ipfr_logs = spark_session.sql("Select * from db.table2 where dt = " + yday_date )
    web_logs = spark_session.sql("Select * from db.table1 where dt = " + yday_date )

    #extract IPFR data for 5223
    df = ipfr_logs.coalesce(1000).filter((ipfr_logs.serverport == '5223'))\
    .select('emsisdn', 'imsi', 'transactiontime', 'apnid','imeisv', 'unixtimestamp')

    #extract weblog data for 5223
    webdf = web_logs.coalesce(1000).filter((web_logs.serverport == '5223'))\
    .select('emsisdn', 'imsi', 'transactiontime', 'apnid','imeisv', 'timestamp')

    #union both dataframes (union is positional, so webdf's 'timestamp' lines up with 'unixtimestamp')
    df = df.union(webdf)

    #take the first 8 characters of IMEISV to get TAC
    df2 = df.withColumn('tac_ipfr', df['imeisv'].substr(1, 8))
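    # e.g. an imeisv of '3512345600000001' gives a tac_ipfr of '35123456' (made-up value)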

    #configure windowing so we can order DF
    windowed_view = Window.partitionBy('emsisdn').orderBy('unixtimestamp')

    #pull the timestamp of the next transaction up onto this line
    df3 = df2.withColumn('next_trans', sqlfunc.lead(df2.unixtimestamp).over(windowed_view))

    #calculate the perceived end of the session (unixtimestamp + transaction time)
    df3 = df3.withColumn('perceived_end', df3.unixtimestamp + df3.transactiontime)

    #IF the perceived end is greater than the start time of the next transaction, use the start of the next transaction time
    df3 = df3.withColumn('real_end', when(df3.perceived_end > df3.next_trans, df3.next_trans).otherwise(df3.perceived_end))

    #Now we know the ACTUAL end time, we need to calculate transactiontime
    df3 = df3.withColumn('new_trans_time', df3.real_end - df3.unixtimestamp)
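    # e.g. with unixtimestamp=100, transactiontime=50 and the next transaction at 120,
    # perceived_end=150 gets capped to real_end=120, so new_trans_time is 20 not 50
    # (illustrative numbers only)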

    #write to Hive
    df3.createOrReplaceTempView('tt')
    spark_session.sql("insert overwrite table keenek1.test_lead select * from tt")
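
Just to make it concrete, below is the kind of test I imagine I'd end up with if I pulled the lead/when logic into a function. Everything here is a rough sketch I made up for illustration: the function name add_real_transaction_time, the test name, the toy data and the local[*] session are not from my real code.

# test_transactions.py - a rough sketch, not my real code
import pyspark.sql.functions as sqlfunc
from pyspark.sql import SparkSession
from pyspark.sql.functions import when
from pyspark.sql.window import Window


def add_real_transaction_time(df):
    # same lead/when steps as in the script above, just wrapped in a function
    windowed_view = Window.partitionBy('emsisdn').orderBy('unixtimestamp')
    df = df.withColumn('next_trans', sqlfunc.lead(df.unixtimestamp).over(windowed_view))
    df = df.withColumn('perceived_end', df.unixtimestamp + df.transactiontime)
    df = df.withColumn('real_end',
                       when(df.perceived_end > df.next_trans, df.next_trans)
                       .otherwise(df.perceived_end))
    return df.withColumn('new_trans_time', df.real_end - df.unixtimestamp)


def test_overlapping_transaction_is_capped():
    # plain local session for the test - no yarn, no Hive
    spark = SparkSession.builder.master('local[*]').appName('pytest').getOrCreate()

    # two transactions for one subscriber: the first claims to last 50s,
    # but the next one already starts 20s later
    input_df = spark.createDataFrame(
        [('sub1', 100, 50),
         ('sub1', 120, 10)],
        ['emsisdn', 'unixtimestamp', 'transactiontime'])

    result = {row.unixtimestamp: row.new_trans_time
              for row in add_real_transaction_time(input_df).collect()}

    assert result[100] == 20   # cut short at the start of the next transaction
    assert result[120] == 10   # last transaction keeps its original duration

Is that the right direction, or is there a sensible way to test the script as it stands?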

0 Answers:

No answers yet.