我正在使用Python的unittest
模块来测试我正在编写的脚本。
该脚本包含如下循环:
// my_script.py
def my_loopy_function():
aggregate_value = 0
for x in range(10):
aggregate_value = aggregate_value + complicated_function(x)
return aggregate_value
def complicated_function(x):
a = do()
b = something()
c = complicated()
return a + b + c
使用unittest
测试complicated_function
我没有问题。但我想通过覆盖my_loopy_function
来测试complicated_function
。
我尝试修改我的脚本,以便my_loopy_function
将complicated_function
作为可选参数,以便我可以从测试中传入一个简单的版本:
// my_modified_script.py
def my_loopy_function(action_function=None):
if action_function is not None:
complicated_function = action_function
aggregate_value = 0
for x in range(10):
aggregate_value = aggregate_value + complicated_function(x)
return aggregate_value
def complicated_function(x):
a = do()
b = something()
c = complicated()
return a + b + c
// test_my_script.py
from myscript import my_loopy_function
class TestMyScript(unittest.TestCase):
test_loopy_function(self):
def simple_function():
return 1
self.assertEqual(10, my_loopy_function(action_function=simple_function))
它没有像我希望的那样工作,有没有关于我应该怎么做的建议?
答案 0 :(得分:2)
最后我使用了Python mock
,这使我可以覆盖complicated_function
,而无需以任何方式调整原始代码。
以下是原始脚本,请注意complicated_function
未作为&{39; my_loopy_function
'传递给action_function
。参数(这是我在之前的解决方案中尝试过的):
// my_script.py
def my_loopy_function():
aggregate_value = 0
for x in range(10):
aggregate_value = aggregate_value + complicated_function(x)
return aggregate_value
def complicated_function(x):
a = do()
b = something()
c = complicated()
return a + b + c
这是我用来测试它的脚本:
// test_my_script.py
import unittest
import mock
from my_script import my_loopy_function
class TestMyModule(unittest.TestCase):
@mock.patch('my_script.complicated_function')
def test_1(self, mocked):
mocked.return_value = 1
self.assertEqual(10, my_loopy_function())
这就像我想要的那样:
mock
模块给了我对内部的后编码访问。 感谢奥斯汀建议使用mock
。
顺便说一下,我使用的是Python 2.7,因此使用了来自PyPI的pip
- 可安装的mock
。
答案 1 :(得分:1)
请勿尝试使用complicated_function
覆盖action_function
,只需使用complicated_function
作为默认action_function
:
def my_loopy_function(action_function=complicated_function):
aggregate_value = 0
for x in range(10):
aggregate_value = aggregate_value + action_function(x)
return aggregate_value
答案 2 :(得分:0)
在您的代码中,您不应该像这样覆盖complicated_function
。如果我试一试,我会得到UnboundLocalError: local variable 'complicated_function' referenced before assignment
。
但也许问题是,在您的实际代码中,您以某种其他方式引用complicated_function
(例如作为模块的成员)?然后通过在测试中覆盖它,你将覆盖实际的complicated_function
,因此你将无法在其他测试中使用它。
执行此操作的正确方法是使用全局变量覆盖 local 变量,如下所示:
def my_loopy_function(action_function=None):
if action_function is None:
action_function = complicated_function
aggregate_value = 0
for x in range(10):
# Use action_function here instead of complicated_function
aggregate_value = aggregate_value + action_function(x)
return aggregate_value