如何在pytest中进行自定义比较?

时间:2019-02-09 13:32:14

标签: python pytest

例如,我想断言两个Pyspark DataFrame具有相同的数据,但是仅使用==来检查它们是否是同一对象。理想情况下,我还要指定订单是否重要。

我尝试编写一个引发AssertionError的函数,但由于它显示了对该函数的回溯,因此会给pytest输出增加很多噪音。

我的另一种想法是模拟DataFrames的__eq__方法,但我不确定这是正确的方法。

编辑:

我考虑只使用返回true或false的函数而不是运算符,但是似乎不适用于pytest_assertrepr_compare。我对该钩子的工作方式还不太熟悉,因此可能有一种方法将其与函数一起使用,而不是与运算符一起使用。

4 个答案:

答案 0 :(得分:1)

我当前的解决方案是使用补丁覆盖DataFrame的__eq__方法。这是熊猫的一个示例,因为它可以更快地进行测试,所以该想法应适用于任何对象。

import pandas as pd
# use this import for python3
# from unittest.mock import patch
from mock import patch


def custom_df_compare(self, other):
    # Put logic for comparing df's here
    # Returning True for demonstration
    return True


@patch("pandas.DataFrame.__eq__", custom_df_compare)
def test_df_equal():
    df1 = pd.DataFrame(
        {"id": [1, 2, 3], "name": ["a", "b", "c"]}, columns=["id", "name"]
    )
    df2 = pd.DataFrame(
        {"id": [2, 3, 4], "name": ["b", "c", "d"]}, columns=["id", "name"]
    )

    assert df1 == df2

还没有尝试过,但是正计划将其添加为固定装置,并使用autouse自动将其用于所有测试。

为了优雅地处理“订单事项”指示符,我正在使用一种类似于pytest.approx的方法,例如,它返回一个具有自己的__eq__的新类:

class SortedDF(object):
    "Indicates that the order of data matters when comparing to another df"

    def __init__(self, df):
        self.df = df

    def __eq__(self, other):
        # Put logic for comparing df's including order of data here
        # Returning True for demonstration purposes
        return True


def test_sorted_df():
    df1 = pd.DataFrame(
        {"id": [1, 2, 3], "name": ["a", "b", "c"]}, columns=["id", "name"]
    )
    df2 = pd.DataFrame(
        {"id": [2, 3, 4], "name": ["b", "c", "d"]}, columns=["id", "name"]
    )

    # Passes because SortedDF.__eq__ is used
    assert SortedDF(df1) == df2
    # Fails because df2's __eq__ method is used
    assert df2 == SortedDF(df2)

我无法解决的次要问题是第二个断言assert df2 == SortedDF(df2)的失败。此订单可与pytest.approx配合使用,但不能在此使用。我曾尝试阅读==运算符,但无法弄清楚如何解决第二种情况。

答案 1 :(得分:0)

要在DataFrames的值之间进行原始比较(必须精确排序),可以执行以下操作:

import pandas as pd
from pyspark.sql import Row

df1 = spark.createDataFrame([Row(a=1, b=2, c=3), Row(a=1, b=3, c=3)])
df2 = spark.createDataFrame([Row(a=1, b=2, c=3), Row(a=1, b=3, c=3)])

pd.testing.assert_frame_equal(df1.toPandas(), df2.toPandas())

如果要按顺序指定,则可以对pandas DataFrame进行一些转换,以首先使用以下功能按特定列进行排序:

def assert_frame_equal_with_sort(results, expected, keycolumns):
  results = results.reindex(sorted(results.columns), axis=1)
  expected = expected.reindex(sorted(expected.columns), axis=1)

  results_sorted = results.sort_values(by=keycolumns).reset_index(drop=True)
  expected_sorted = expected.sort_values(by=keycolumns).reset_index(drop=True)

  pd.testing.assert_frame_equal(results_sorted, expected_sorted)


df1 = spark.createDataFrame([Row(a=1, b=2, c=3), Row(a=1, b=3, c=3)])
df2 = spark.createDataFrame([Row(a=1, b=3, c=3), Row(a=1, b=2, c=3)])

assert_frame_equal_with_sort(df1.toPandas(), df2.toPandas(), ['b'])

答案 2 :(得分:0)

您可以使用pytest钩子之一,特别是pytest_assertrepr_compare。您可以在其中定义您要比较的内容以及如何比较,文档也不错,并带有示例。祝你好运。 :)

答案 3 :(得分:0)

只需使用 pandas.Dataframe.equals 方法 https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.equals.html

例如

assert df1.equals(df2)

assert可以与任何返回布尔值的东西一起使用。所以可以,您可以编写任何自定义比较功能来比较两个对象。只要自定义函数返回一个布尔值。但是,在这种情况下,由于熊猫已经提供了自定义功能,因此不需要自定义功能