是否有一种方法可以模拟在另一个函数调用中调用的函数返回?例如:
def bar():
return "baz"
def foo():
return bar()
class Tests(unittest.TestCase):
def test_code(self):
# hijack bar() here to return "bat" instead
assert(foo() == "bat")
我尝试使用@mock.patch
,但是发现它只允许我模拟正在调用的函数,而不能模拟由于调用其他函数而被调用的函数。
答案 0 :(得分:2)
您可以编写上下文管理器来临时交换全局名称空间中的对象:
class Hijack:
def __init__(self, name, replacement, namespace):
self.name = name
self.replacement = replacement
self.namespace = namespace
def __enter__(self):
self.original = self.namespace[self.name]
self.namespace[self.name] = self.replacement
def __exit__(self, *args):
self.namespace[self.name] = self.original
您可以使用劫持方法调用模拟函数:
def bar():
return "baz"
def bar_mock():
return "bat"
def foo():
return bar()
class Tests(unittest.TestCase):
def test_code(self):
with Hijack('bar', bar_mock, globals()):
assert(foo() == "bat")
这是一种非常通用的方法,可以在单元测试之外使用。实际上,将其推广到可用于任何可表示为某种映射的可变对象的工作非常简单:
class Hijack:
def __init__(self, name, replacement, namespace, getter=None, setter=None):
self.name = name
self.replacement = replacement
self.namespace = namespace
self.getter = type(namespace).__getitem__ if getter is None else getter
self.setter = type(namespace).__setitem__ if setter is None else setter
def __enter__(self):
self.original = self.getter(self.namespace, self.name)
self.setter(self.namespace, self.name, self.replacement)
def __exit__(self, *args):
self.setter(self.namespace, self.name, self.original)
对于类和其他对象,应使用getter=getattr
和setter=setattr
。对于None
优于KeyError
的情况,可以使用getter=dict.get
等。
答案 1 :(得分:2)
通用补丁
unittest.mock.patch
确实可以执行我的其他答案建议。您可以根据需要添加任意数量的@patch
注释,所选对象将被修补:
from unittest.mock import patch
def bar():
return "baz"
def foo():
return bar()
class Tests(unittest.TestCase):
@patch(__name__ + '.bar', lambda: 'bat')
def test_code(self):
assert(foo() == "bat")
在此配置中,bar
完成后将恢复功能test_code
。如果您希望同一补丁程序适用于您班级中的所有测试用例,请对整个班级进行注释:
@patch(__name__ + '.bar', lambda: 'bat')
class Tests(unittest.TestCase):
def test_code(self):
assert(foo() == "bat")
修补全局变量
您也可以在全局名称空间上unittest.mock.patch.dict
以获得相同的结果:
class Tests(unittest.TestCase):
@patch.dict(globals(), {'bar': lambda: 'bat'})
def test_code(self):
assert(foo() == "bat")
答案 2 :(得分:1)
此外,您可以使用seattr
固定装置的pytest monkeypatch
方法。对于第一个参数,它接受要修补的对象,否则字符串将被解释为点分导入路径,最后一部分是属性名称:
# foo_module.py
def bar():
return "baz"
def foo():
return bar()
# test_foo.py
from foo_module import foo
def test_foo(monkeypatch):
monkeypatch.setattr('foo_module.bar', lambda: 'bat')
assert foo() == "bat"