I have a unit-test module in my PySpark code, but I am not sure how to execute it.
Here is my code; it just reads a DataFrame with two columns, Day and Amount, and computes sum(Amount) grouped by Day. The script is saved as test.py:
import sys
from pyspark.sql import SparkSession
from pyspark.sql import Row
from pyspark.sql.functions import *
import pytest
import unittest

def main():
    spark = (SparkSession.builder
             .appName("MyApp")
             .config("spark.sql.shuffle.partitions", "2")
             .getOrCreate())
    # Start ETL
    data = extract_data(spark)
    data_transformed = transform_data(data)
    # log the success and terminate Spark application
    spark.stop()
    return None

# Extract Data
def extract_data(spark):
    df = (spark.read
          .option("inferSchema", "true")
          .option("header", "true")
          .csv("myfile.txt"))
    return df

# Transform Data
def transform_data(df):
    df_transformed = (df.groupBy("Day")
                        .sum("Amount")
                        .withColumnRenamed("sum(Amount)", "total_amt")
                        .select("Day", "total_amt"))
    return df_transformed

pytestmark = pytest.mark.usefixtures("spark")

def my_test_func(self):
    test_input = [Row(Day=1, Amount=10),
                  Row(Day=1, Amount=20)]
    input_df = spark.createDataFrame(test_input)
    result = transform_data(input_df).select("total_amt").collect()[0]
    expected_result = 30
    self.assertEqual(result, expected_result)
    print("test done")

if __name__ == '__main__':
    main()
I am new to PySpark and have a few questions -
Answer 0 (score: 2):
You can keep your Spark unit tests in a separate folder that mirrors the source tree. For example,
src
+-- jobs
    +-- job1
tests
+-- jobs
    +-- job1
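For this layout to work, the transformation logic has to live in an importable module rather than inside test.py's main(). A minimal sketch, assuming the code under test is saved as src/jobs/job1.py (the file name is illustrative):

# src/jobs/job1.py -- hypothetical module holding the logic under test
from pyspark.sql import DataFrame

def transform_data(df: DataFrame) -> DataFrame:
    # Sum Amount per Day and rename the aggregate column
    return (df.groupBy("Day")
              .sum("Amount")
              .withColumnRenamed("sum(Amount)", "total_amt")
              .select("Day", "total_amt"))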
Then this is how you would write the test case:
import unittest

from pyspark.sql import Row, SparkSession

from jobs.job1 import transform_data  # the function under test (assumes src is on PYTHONPATH)


class TestJob1(unittest.TestCase):
    def setUp(self):
        """
        Start Spark, define config and path to test data
        """
        self.spark = (SparkSession.builder
                      .appName("MyApp")
                      .config("spark.sql.shuffle.partitions", "2")
                      .getOrCreate())
        # self.job1 = Job1(self.spark)  # only if your job logic is wrapped in a class

    def tearDown(self):
        """
        Stop Spark
        """
        self.spark.stop()

    def test_yourtest_code(self):
        test_input = [Row(Day=1, Amount=10),
                      Row(Day=1, Amount=20)]
        input_df = self.spark.createDataFrame(test_input)
        # collect() returns Row objects, so unpack the value before comparing
        result = transform_data(input_df).select("total_amt").collect()[0][0]
        expected_result = 30
        self.assertEqual(result, expected_result)
        print("test done")
You can run the test case with (assuming the class above lives in tests/jobs/test_job1.py; adjust the module path to your actual file name):

python -m unittest tests.jobs.test_job1.TestJob1
python -m unittest tests.jobs.test_job1.TestJob1.test_yourtest_code
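Since the question already imports pytest, a session-scoped fixture is the usual pytest alternative to setUp/tearDown, and is what pytest.mark.usefixtures("spark") expects to find. A minimal sketch, assuming a hypothetical tests/conftest.py, which pytest discovers automatically:

# tests/conftest.py -- shared "spark" fixture for pytest-style tests
import pytest
from pyspark.sql import SparkSession

@pytest.fixture(scope="session")
def spark():
    # One SparkSession for the whole test run; stopped when the run ends
    spark = (SparkSession.builder
             .appName("MyApp-tests")
             .config("spark.sql.shuffle.partitions", "2")
             .getOrCreate())
    yield spark
    spark.stop()

With that in place, a test is a plain function that takes spark as an argument, and you can run the whole suite with python -m pytest tests.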