如何使用Scala模拟Spark DataFrameReader?

时间:2019-04-03 07:42:54

标签: scala unit-testing apache-spark mocking

我想对使用sparkSession.read.jdbc(...)从RDBMS读取DataFrame的代码进行单元测试。但是我没有找到一种方法来模拟DataFrameReader以返回虚拟DataFrame进行测试。

代码示例:

object ConfigurationLoader {

def readTable(tableName: String)(implicit spark: SparkSession): DataFrame = {
    spark.read
      .format("jdbc")
      .option("url", s"$postgresUrl/$postgresDatabase")
      .option("dbtable", tableName)
      .option("user", postgresUsername)
      .option("password", postgresPassword)
      .option("driver", postgresDriver)
      .load()
  }

def loadUsingFilter(dummyFilter: String*)(implicit spark: SparkSession): DataFrame = {
    readTable(postgresFilesTableName)
      .where(col("column").isin(fileTypes: _*))
  }
}

第二个问题-模拟scala对象,看来我需要使用其他方法来创建此类服务。

2 个答案:

答案 0 :(得分:3)

我认为,单元测试并非旨在测试数据库连接。这应该在集成测试中完成,以检查所有部分是否协同工作。单元测试仅用于测试您的功能逻辑,而不是激发从数据库读取数据的能力。

这就是为什么我会稍稍不同地设计您的代码,而不必关心数据库的原因。

/** This, I don't test. I trust spark.read */
def readTable(tableName: String)(implicit spark: SparkSession): DataFrame = {
    spark.read
    .option(...)
    ...
    .load()
    // Nothing more
}

/** This I test, this is my logic.
def transform(df : DataFrame, dummyFilter: String*): DataFrame = {
    df
      .where(col("column").isin(fileTypes: _*))
}

然后我在生产中以这种方式使用代码。

val source = readTable("...")
val result = transform(source, filter)

现在包含我的逻辑的transform很容易测试。如果您想知道如何创建虚拟数据帧,我喜欢的一种方法是:

val df = Seq((1, Some("a"), true), (2, Some("b"), false), 
      (3, None, true)).toDF("x", "y", "z")
// and the test
val result = transform(df, filter)
result should be ...

答案 1 :(得分:1)

如果要测试sparkSession.read.jdbc(...),可以使用内存中的H2数据库。有时我在编写学习测试时会这样做。您可以在此处找到一个示例:https://github.com/bartosz25/spark-scala-playground/blob/d3cad26ff236ae78884bdeb300f2e59a616dc479/src/test/scala/com/waitingforcode/sql/LoadingDataTest.scala但是请注意,“真实” RDBMS可能会遇到一些细微的差异。

另一方面,您可以更好地分离代码的关注点,并以不同的方式创建DataFrame,例如使用toDF(...)方法。您可以在此处找到示例:https://github.com/bartosz25/spark-scala-playground/blob/77ea416d2493324ddd6f3f2be42122855596d238/src/test/scala/com/waitingforcode/sql/CorrelatedSubqueryTest.scala

最后和IMO,如果必须模拟DataFrameReader,则意味着也许与代码分离有关。例如,您可以将所有过滤器放在Filters对象中,并分别测试每个过滤器。映射或聚合功能相同。 2年前,我写了一篇有关测试Apache Spark的博客文章-https://www.waitingforcode.com/apache-spark/testing-spark-applications/read,它描述了RDD API,但是分离问题的想法是相同的。


已更新:

object Filters {
  def isInFileTypes(inputDataFrame: DataFrame, fileTypes: Seq[String]): DataFrame = {
    inputDataFrame.where(col("column").isin(fileTypes: _*))
  }
}

object ConfigurationLoader {

def readTable(tableName: String)(implicit spark: SparkSession): DataFrame = {
    val input = spark.read
      .format("jdbc")
      .option("url", s"$postgresUrl/$postgresDatabase")
      .option("dbtable", tableName)
      .option("user", postgresUsername)
      .option("password", postgresPassword)
      .option("driver", postgresDriver)
      .load()
    Filters.isInFileTypes(input, Seq("txt", "doc")
  }

有了它,您就可以根据需要测试过滤逻辑:)如果您有更多的过滤器并想要对其进行测试,则还可以将它们组合为一个方法,并传递任何DataFrame所需的信息,然后添加: ) 除非有充分的理由,否则不应该测试.load()。这是Apache Spark内部逻辑,已经过测试。


更新,回答:

  

因此,现在我可以测试过滤器了,但是如何确保readTable确实使用了正确的过滤器(抱歉,为了更全面,这只是完整说明的问题)。也许您有一些简单的方法来模拟Scala对象(这实际上是第二个问题)。 – 14分钟前dytyniak

object MyApp {
  def main(args: Array[String]): Unit = {
    val inputDataFrame = readTable(postgreSQLConnection)
    val outputDataFrame = ProcessingLogic.generateOutputDataFrame(inputDataFrame)  
  }
}

object ProcessingLogic {
  def generateOutputDataFrame(inputDataFrame: DataFrame): DataFrame = {
    // Here you apply all needed filters, transformations & co
  }
}

如您所见,这里不需要模拟object。这似乎是多余的,但这不是因为您可以使用Filters对象来隔离地测试每个过滤器,并且可以使用ProcessingLogic对象(例如,仅用于命名)将所有处理逻辑组合在一起。您可以用任何有效的方式创建DataFrame。缺点是您将需要显式定义一个架构或使用case classes,因为在PostgreSQL源代码中,Apache Spark将自动解析该架构(我在这里进行了解释:https://www.waitingforcode.com/apache-spark-sql/schema-projection/read)。