我有一个泛型类,用户应该子类化以实现某些方法。可以有几个级别的子类。像
这样的东西class Thing(object):
def fun(self, *args, **kwargs):
raise NotImplementedError()
class Bell(Thing):
def fun(self):
return 1
class Whistle(Bell):
def fun(self):
return super(Whistle, self).fun() + 1
我想计算在使用fun()
的任何子类时调用Thing
的次数。因为装饰器不是继承的,并且因为我不希望用户必须记住装饰他们的fun()
方法,所以我的理解是元类是要走的路。所以我写了
class CountCalls(type):
def __new__(cls, name, bases, attrs):
attrs["_original_fun"] = attrs["fun"]
attrs["fun"] = countcalls(attrs["_original_fun"])
return super(CountCalls, cls).__new__(cls, name, bases, attrs)
其中countcalls
是计算调用次数的经典装饰器:
def countcalls(fn):
def wrapper(*args, **kwargs):
wrapper.ncalls += 1
return fn(*args, **kwargs)
wrapper.ncalls = 0
wrapper.__name__ = fn.__name__
wrapper.__doc__ = fn.__doc__
return wrapper
并将Thing
的定义更改为
class Thing(object):
__metaclass__ = CountCalls
def fun(self, *args, **kwargs):
raise NotImplementedError()
问题:这有效,但是当fun()方法增加所有实例的调用次数时,它会产生意想不到的副作用>调用任何实例:
>>> b1 = Bell()
>>> b2 = Bell()
>>> b1.fun.ncalls, b2.fun.ncalls
(0, 0)
>>> b1.fun()
1
>>> b1.fun.ncalls, b2.fun.ncalls
(1, 1)
问题:如何计算每个实例的fun()
来电次数?感觉我应该在元类中实现__init__
而不是__new__
,但到目前为止我还没有找到正确的语法。例如,使用
def __init__(self, name, bases, attrs):
attrs["_original_fun"] = attrs["fun"]
attrs["fun"] = countcalls(attrs["_original_fun"])
super(CountCalls, self).__init__(name, bases, attrs)
产量
>>> b = Bell()
>>> b.fun.ncalls
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
AttributeError: 'function' object has no attribute 'ncalls'
谢谢!
答案 0 :(得分:2)
您可以通过稍微更改继承模式来跳过元类:
class Thing(object):
def __init__(self):
self.fun_calls = 0
def fun(self, *args, **kwargs):
self.fun_calls += 1
self._fun(*args, **kwargs)
def _fun(self, *args, **kwargs):
raise NotImplementedError()
然后只需覆盖子类中的_fun
。这使得每个实例自动计数,并且它(imo)比元类实现更清晰,更易理解。
答案 1 :(得分:0)
为了跟踪每个实例而不是每个函数的调用,你需要变量跟踪它在实例上,如:
def countcalls(fn):
def wrapper(self,*args, **kwargs):
self._calls_dict[wrapper]+=1
return fn(self,*args, **kwargs)
wrapper.__name__ = fn.__name__
wrapper.__doc__ = fn.__doc__
return wrapper
虽然您需要某种方法来初始化第一次调用的字典:
import collections
def countcalls(fn):
def wrapper(self,*args, **kwargs):
if not hasattr(self,"_calls_dict"):
self._calls_dict = collections.defaultdict(int)
self._calls_dict[wrapper.__name__]+=1
return fn(self,*args, **kwargs)
wrapper.__name__ = fn.__name__
wrapper.__doc__ = fn.__doc__
return wrapper
虽然请注意,如果将此装饰器应用于classmethod
或staticmethod
,这将会崩溃,所以请小心如何实现此功能。
这也使得如何检索调用次数变得复杂,但是使用单独的函数可以相当容易地完成:
from types import MethodType
def get_calls(method):
if not isinstance(method,MethodType):
raise TypeError("must pass bound method as argument")
func = method.__func__
inst = method.__self__
dct = getattr(inst,"_calls_dict",None)
if dct is None or func not in dct:
return 0 #maybe raise an error instead?
else:
return dct[func]