Python-无法模拟对继承的类的调用

时间:2018-09-18 15:01:18

标签: python python-3.x python-unittest

我有这个主班

def main(args):
    if type == train_pipeline_type:
        strategy = TrainPipelineStrategy()
    else:
        strategy = TestPipelineStrategy()
    for table in fetch_table_information_by_region(region):
        split_required = DataUtils.load_from_dict(table, "split_required")
        if split_required:
            strategy.split(spark=spark, table_name=table_name,
                           data_loc=filtered_data_location, partition_column=partition_column,
                           split_output_dir= split_output_dir)
            logger.info("Data Split for table : {} completed".format(table_name))

我的TrainPipelineStrategy和TestPipelineStrategy看起来像这样-

class PipelineTypeStrategy(object):

    def partition_data(self, x):
        # Something

    def prepare_split_data(self, y):
        # Something

    def write_split_data(self, z):
        # Something

    def split(self, p):
        # Something


class TrainPipelineStrategy(PipelineTypeStrategy):
    """"""


class TestPipelineStrategy(PipelineTypeStrategy):

    def write_split_data(self, y):
        # Something else

我的测试用例- 我需要通过在main方法中模拟split功能来测试调用split的次数。

这是我尝试过的-

@patch('module.PipelineTypeStrategy.TrainPipelineStrategy')
    def test_split_data_main_split_data_call_count(self, fake_train):
        fake_train_functions = mock.Mock()
        fake_train_functions.split.return_value = None
        fake_train.return_value = fake_train_functions
        test_args = ["", "--x=6"]
        SplitData.main(args=test_args)
        assert fake_train_functions.split.call_count == 10

当我尝试运行测试时,它会创建模拟,但最终最终会调用实际的split函数。我在做什么错了?

1 个答案:

答案 0 :(得分:0)

此代码的主要问题是,设置patch的方式将是如果TrainPipelineStrategyPipelineTypeStrategy的嵌套类,而TrainPipelineStrategy是一个子类的PipelineTypeStrategy

由于TrainPipelineStrategy继承自PipelineTypeStrategy,因此可以直接访问split,因此您可以修补split而无需引用PipelineTypeStrategy(除非您明确想要修补split中定义的PipelineTypeStrategy的版本。

但是,如果您只想模拟split类的PipelineTypeStrategy方法,则应该使用patch.object装饰器来模拟split而不是模拟整个模型类,因为它更干净。这是一个示例:

class TestClass(unittest.TestCase):
    @patch.object(TrainPipelineStrategy, 'split', return_value=None)
    def test_split_data_main_split_data_call_count(self, mock_split):
        test_args = ["", "--x=6"]
        SplitData.main(args=test_args)
        self.assertEqual(mock_split.call_count, 10)