我是Scala的新手,在编写单元测试时遇到了问题。
我正在尝试比较和检查Scala中两个Spark DataFrames的单元测试的相等性,并意识到没有简单的方法来检查两个Spark DataFrame的相等性。
C ++等效代码将是(假设DataFrame在C ++中表示为双数组):
int expected[10][2];
int result[10][2];
for (int row = 0; row < 10; row++) {
for (int col = 0; col < 2; col++) {
if (expected[row][col] != result[row][col]) return false;
}
}
实际测试将涉及根据DataFrames列的数据类型测试相等性(使用浮点数的精确容差进行测试等)。
似乎没有一种简单的方法可以使用Scala迭代循环遍历DataFrame中的所有元素,而其他用于检查两个DataFrame(例如df1.except(df2)
)相等的解决方案在我的情况下不起作用,因为我需要能够为浮动和双打的容差测试平等提供支持。
当然,我可以尝试事先对所有元素进行舍入并比较之后的结果,但我想看看是否有任何其他解决方案可以让我遍历DataFrames以检查是否相等。
答案 0 :(得分:3)
import org.scalatest.{BeforeAndAfterAll, FeatureSpec, Matchers}
outDf.collect() should contain theSameElementsAs (dfComparable.collect())
# or ( obs order matters ! )
outDf.except(dfComparable).toDF().count should be(0)
答案 1 :(得分:1)
如果要检查两个数据帧是否相等以用于测试目的,可以使用subtract()
数据帧方法(版本1.3及更高版本支持)
您可以检查两个数据帧的差异是空还是0。
例如df1.subtract(df2).count() == 0
答案 2 :(得分:0)
假设你有一个固定的col和行#,一个解决方案可以通过行索引加入Df&#s;如果你没有记录的id&#39; s,然后直接迭代在最终的DF中[包含DF&amp; s的所有列]。 像这样:
Schemas
DF1
root
|-- col1: double (nullable = true)
|-- col2: double (nullable = true)
|-- col3: double (nullable = true)
DF2
root
|-- col1: double (nullable = true)
|-- col2: double (nullable = true)
|-- col3: double (nullable = true)
df1
+----------+-----------+------+
| col1| col2| col3|
+----------+-----------+------+
|1.20000001| 1.21| 1.2|
| 2.1111| 2.3| 22.2|
| 3.2|2.330000001| 2.333|
| 2.2444| 2.344|2.3331|
+----------+-----------+------+
df2
+------+-----+------+
| col1| col2| col3|
+------+-----+------+
| 1.2| 1.21| 1.2|
|2.1111| 2.3| 22.2|
| 3.2| 2.33| 2.333|
|2.2444|2.344|2.3331|
+------+-----+------+
Added row index
df1
+----------+-----------+------+---+
| col1| col2| col3|row|
+----------+-----------+------+---+
|1.20000001| 1.21| 1.2| 0|
| 2.1111| 2.3| 22.2| 1|
| 3.2|2.330000001| 2.333| 2|
| 2.2444| 2.344|2.3331| 3|
+----------+-----------+------+---+
df2
+------+-----+------+---+
| col1| col2| col3|row|
+------+-----+------+---+
| 1.2| 1.21| 1.2| 0|
|2.1111| 2.3| 22.2| 1|
| 3.2| 2.33| 2.333| 2|
|2.2444|2.344|2.3331| 3|
+------+-----+------+---+
Combined DF
+---+----------+-----------+------+------+-----+------+
|row| col1| col2| col3| col1| col2| col3|
+---+----------+-----------+------+------+-----+------+
| 0|1.20000001| 1.21| 1.2| 1.2| 1.21| 1.2|
| 1| 2.1111| 2.3| 22.2|2.1111| 2.3| 22.2|
| 2| 3.2|2.330000001| 2.333| 3.2| 2.33| 2.333|
| 3| 2.2444| 2.344|2.3331|2.2444|2.344|2.3331|
+---+----------+-----------+------+------+-----+------+
这是你如何做到的:
println("Schemas")
println("DF1")
df1.printSchema()
println("DF2")
df2.printSchema()
println("df1")
df1.show
println("df2")
df2.show
val finaldf1 = df1.withColumn("row", monotonically_increasing_id())
val finaldf2 = df2.withColumn("row", monotonically_increasing_id())
println("Added row index")
println("df1")
finaldf1.show()
println("df2")
finaldf2.show()
val joinedDfs = finaldf1.join(finaldf2, "row")
println("Combined DF")
joinedDfs.show()
val tolerance = 0.001
def isInValidRange(a: Double, b: Double): Boolean ={
Math.abs(a-b)<=tolerance
}
joinedDfs.take(10).foreach(row => {
assert( isInValidRange(row.getDouble(1), row.getDouble(4)) , "Col1 validation. Row %s".format(row.getLong(0)+1))
assert( isInValidRange(row.getDouble(2), row.getDouble(5)) , "Col2 validation. Row %s".format(row.getLong(0)+1))
assert( isInValidRange(row.getDouble(3), row.getDouble(6)) , "Col3 validation. Row %s".format(row.getLong(0)+1))
})
注意:断言不是序列化的,解决方法是使用take()来避免错误。