在模拟中修补生成器

时间:2018-01-31 15:42:17

标签: python unit-testing mocking

我正在尝试为我的单元测试创​​建一个生成器方法的模拟版本。设置返回值很简单:

    patcher = unittest.mock.patch.multiple("__main__.MyClass",
        method_one=DEFAULT,
        ...
    )
    patcher.start()
    MyClass.method_one.return_value = my_mock_gen(some_params)

然而,这只会工作一次,每次都必须重置:

    for value in my_obj.method_one(some_params):
        # do stuff with value

    my_obj.method_one.reset_mock() # doesn't work with `yield from` as well

另外,我无法根据some_params生成模拟值。是否有一种处理这种情况的首选方法,我错过了?

2 个答案:

答案 0 :(得分:3)

如果你想在调用模拟时运行一个函数,那就是side_effect,而不是return_value

MyClass.method_one.side_effect = my_mock_gen

答案 1 :(得分:0)

解决方案是将generator转换为list进行测试。它有效,但可能不好。

    #!/usr/bin/python
    import mock
    import unittest


    class FooBar(object):

        def method_one(self, a, b):
            for i in range(a, b):
                yield i * i

        def bar(self):
            return 'bar'


    def mock_generator(a, b):
        for i in range(a, b):
            yield (i + 1) * i


    class TestFooBar(unittest.TestCase):

        def test_method_one(self):
            # set default to a list
            mock_method_one = mock.Mock(return_value=['aaa', 'ccc'])
            # this is how I access the class
            patcher = mock.patch.multiple(
                "python_tools.tests.patch_multiple.FooBar",
                method_one=mock_method_one,
                bar=mock.Mock(return_value='bar')
            )

            patcher.start()
            # method_one return generator only supposed to be read once
            # convert generator to list for testing, crazy?! not sure
            # [6, 12] after converting to list
            mock_method_one.return_value = list(mock_generator(2, 4))
            fb = FooBar()

            # no change
            self.assertEqual(fb.method_one(200, 500), [6, 12])

            # no change
            for item in fb.method_one(2, 5):
                self.assertTrue(item, [6, 12])

            # no change
            self.assertEqual(fb.method_one(1, 6), [6, 12])

            patcher.stop()