我有这个主班
def main(args):
if type == train_pipeline_type:
strategy = TrainPipelineStrategy()
else:
strategy = TestPipelineStrategy()
for table in fetch_table_information_by_region(region):
split_required = DataUtils.load_from_dict(table, "split_required")
if split_required:
strategy.split(spark=spark, table_name=table_name,
data_loc=filtered_data_location, partition_column=partition_column,
split_output_dir= split_output_dir)
logger.info("Data Split for table : {} completed".format(table_name))
我的TrainPipelineStrategy和TestPipelineStrategy看起来像这样-
class PipelineTypeStrategy(object):
def partition_data(self, x):
# Something
def prepare_split_data(self, y):
# Something
def write_split_data(self, z):
# Something
def split(self, p):
# Something
class TrainPipelineStrategy(PipelineTypeStrategy):
""""""
class TestPipelineStrategy(PipelineTypeStrategy):
def write_split_data(self, y):
# Something else
我的测试用例- 我需要通过在main方法中模拟split功能来测试调用split的次数。
这是我尝试过的-
@patch('module.PipelineTypeStrategy.TrainPipelineStrategy')
def test_split_data_main_split_data_call_count(self, fake_train):
fake_train_functions = mock.Mock()
fake_train_functions.split.return_value = None
fake_train.return_value = fake_train_functions
test_args = ["", "--x=6"]
SplitData.main(args=test_args)
assert fake_train_functions.split.call_count == 10
当我尝试运行测试时,它会创建模拟,但最终最终会调用实际的split函数。我在做什么错了?
答案 0 :(得分:0)
此代码的主要问题是,设置patch
的方式将是如果TrainPipelineStrategy
是PipelineTypeStrategy
的嵌套类,而TrainPipelineStrategy
是一个子类的PipelineTypeStrategy
。
由于TrainPipelineStrategy
继承自PipelineTypeStrategy
,因此可以直接访问split
,因此您可以修补split
而无需引用PipelineTypeStrategy
(除非您明确想要修补split
中定义的PipelineTypeStrategy
的版本。
但是,如果您只想模拟split
类的PipelineTypeStrategy
方法,则应该使用patch.object
装饰器来模拟split
而不是模拟整个模型类,因为它更干净。这是一个示例:
class TestClass(unittest.TestCase):
@patch.object(TrainPipelineStrategy, 'split', return_value=None)
def test_split_data_main_split_data_call_count(self, mock_split):
test_args = ["", "--x=6"]
SplitData.main(args=test_args)
self.assertEqual(mock_split.call_count, 10)