检查Scala中两个Spark DataFrame的相等性

时间:2016-11-08 22:38:35

标签: scala unit-testing apache-spark spark-dataframe

我是Scala的新手,在编写单元测试时遇到了问题。

我正在尝试比较和检查Scala中两个Spark DataFrames的单元测试的相等性,并意识到没有简单的方法来检查两个Spark DataFrame的相等性。

C ++等效代码将是(假设DataFrame在C ++中表示为双数组):

    int expected[10][2];
    int result[10][2];
    for (int row = 0; row < 10; row++) {
        for (int col = 0; col < 2; col++) {
            if (expected[row][col] != result[row][col]) return false;
        }
    }

实际测试将涉及根据DataFrames列的数据类型测试相等性(使用浮点数的精确容差进行测试等)。

似乎没有一种简单的方法可以使用Scala迭代循环遍历DataFrame中的所有元素,而其他用于检查两个DataFrame(例如df1.except(df2))相等的解决方案在我的情况下不起作用,因为我需要能够为浮动和双打的容差测试平等提供支持。

当然,我可以尝试事先对所有元素进行舍入并比较之后的结果,但我想看看是否有任何其他解决方案可以让我遍历DataFrames以检查是否相等。

3 个答案:

答案 0 :(得分:3)

import org.scalatest.{BeforeAndAfterAll, FeatureSpec, Matchers}

outDf.collect() should contain theSameElementsAs (dfComparable.collect())
# or ( obs order matters ! )
outDf.except(dfComparable).toDF().count should be(0) 

答案 1 :(得分:1)

如果要检查两个数据帧是否相等以用于测试目的,可以使用subtract()数据帧方法(版本1.3及更高版本支持)

您可以检查两个数据帧的差异是空还是0。 例如df1.subtract(df2).count() == 0

答案 2 :(得分:0)

假设你有一个固定的col和行#,一个解决方案可以通过行索引加入Df&#s;如果你没有记录的id&#39; s,然后直接迭代在最终的DF中[包含DF&amp; s的所有列]。 像这样:

Schemas
DF1
root
 |-- col1: double (nullable = true)
 |-- col2: double (nullable = true)
 |-- col3: double (nullable = true)

DF2
root
 |-- col1: double (nullable = true)
 |-- col2: double (nullable = true)
 |-- col3: double (nullable = true)

df1
+----------+-----------+------+
|      col1|       col2|  col3|
+----------+-----------+------+
|1.20000001|       1.21|   1.2|
|    2.1111|        2.3|  22.2|
|       3.2|2.330000001| 2.333|
|    2.2444|      2.344|2.3331|
+----------+-----------+------+

df2
+------+-----+------+
|  col1| col2|  col3|
+------+-----+------+
|   1.2| 1.21|   1.2|
|2.1111|  2.3|  22.2|
|   3.2| 2.33| 2.333|
|2.2444|2.344|2.3331|
+------+-----+------+

Added row index
df1
+----------+-----------+------+---+
|      col1|       col2|  col3|row|
+----------+-----------+------+---+
|1.20000001|       1.21|   1.2|  0|
|    2.1111|        2.3|  22.2|  1|
|       3.2|2.330000001| 2.333|  2|
|    2.2444|      2.344|2.3331|  3|
+----------+-----------+------+---+

df2
+------+-----+------+---+
|  col1| col2|  col3|row|
+------+-----+------+---+
|   1.2| 1.21|   1.2|  0|
|2.1111|  2.3|  22.2|  1|
|   3.2| 2.33| 2.333|  2|
|2.2444|2.344|2.3331|  3|
+------+-----+------+---+

Combined DF
+---+----------+-----------+------+------+-----+------+
|row|      col1|       col2|  col3|  col1| col2|  col3|
+---+----------+-----------+------+------+-----+------+
|  0|1.20000001|       1.21|   1.2|   1.2| 1.21|   1.2|
|  1|    2.1111|        2.3|  22.2|2.1111|  2.3|  22.2|
|  2|       3.2|2.330000001| 2.333|   3.2| 2.33| 2.333|
|  3|    2.2444|      2.344|2.3331|2.2444|2.344|2.3331|
+---+----------+-----------+------+------+-----+------+

这是你如何做到的:

println("Schemas")
    println("DF1")
    df1.printSchema()
    println("DF2")
    df2.printSchema()
    println("df1")
    df1.show
    println("df2")
    df2.show
    val finaldf1 = df1.withColumn("row", monotonically_increasing_id())
    val finaldf2 = df2.withColumn("row", monotonically_increasing_id())
    println("Added row index")
    println("df1")
    finaldf1.show()
    println("df2")
    finaldf2.show()

    val joinedDfs = finaldf1.join(finaldf2, "row")
    println("Combined DF")
    joinedDfs.show()

    val tolerance = 0.001
    def isInValidRange(a: Double, b: Double): Boolean ={
      Math.abs(a-b)<=tolerance
    }
    joinedDfs.take(10).foreach(row => {
      assert( isInValidRange(row.getDouble(1), row.getDouble(4)) , "Col1 validation. Row %s".format(row.getLong(0)+1))
      assert( isInValidRange(row.getDouble(2), row.getDouble(5)) , "Col2 validation. Row %s".format(row.getLong(0)+1))
      assert( isInValidRange(row.getDouble(3), row.getDouble(6)) , "Col3 validation. Row %s".format(row.getLong(0)+1))
    })
  

注意:断言不是序列化的,解决方法是使用take()来避免错误。