Python:从子类调用超类方法时超过最大递归深度

时间:2011-07-03 11:05:10

标签: python exception inheritance recursion

我有三个类的继承链:Baz继承自Bar的继承自Foo。我想在Foo类中定义一个方法,当在子类中调用时,返回其父类的输出,并附加自己的东西。这是所需的输出:

>>> foo = Foo()
>>> bar = Bar()
>>> baz = Baz()
>>> print foo.get_defining_fields()
['in Foo']
>>> print bar.get_defining_fields()
['in Foo', 'in Bar']
>>> print baz.get_defining_fields()
['in Foo', 'in Bar', 'in Baz']

问题是,我误解了在子类调用的超类方法中使用super()或其他一些子类化细节的问题。这个位工作正常:

>>> print foo.get_defining_fields()
['in Foo']

但是bar.get_defining_fields()会产生一个无限循环,自行运行直到RuntimeError被提升,而不是呼叫foo.get_defining_fields并停在那里,就像我想要的那样。

这是代码。

class Foo(object):
    def _defining_fields(self):
        return ['in Foo']

    def get_defining_fields(self):
        if self.__class__ == Foo:
            # Top of the chain, don't call super()
            return self._defining_fields()
        return super(self.__class__, self).get_defining_fields() + self._defining_fields()

class Bar(Foo):
    def _defining_fields(self):
        return ['in Bar']

class Baz(Bar):
    def _defining_fields(self):
        return ['in Baz']

因此get_defining_fields在超类中定义,并且其中的super()调用使用self.__class__来尝试将正确的子类名称传递给每个子类中的调用。在Bar中调用时,它会解析为super(Bar, self).get_defining_fields(),以便foo.get_defining_fields()返回的列表将添加到far.get_defining_fields()返回的列表之前。

如果你正确理解了Python的继承机制和super()的内部运作,可能是一个简单的错误,但是因为我显然不这样做,如果有人能指出我这样做的正确方法,我会很感激


编辑:根据Daniel Roseman的回答,我尝试用此表单替换super()来电:return super(Foo, self).get_defining_fields() + self._defining_fields()

现在无限递归不再发生,但在调用bar.get_defining_fields()时出现了不同的错误:

AttributeError: 'super' object has no attribute 'get_defining_fields'

其他一些东西仍然不对。


编辑:是的,终于找到了我在这里失踪的东西。将丹尼尔的最新答案标记为已接受的答案。

1 个答案:

答案 0 :(得分:6)

问题在于:

 return super(self.__class__, self)...

self.__class__始终引用当前具体类。所以在Bar中,它指的是Bar,而在Baz中,它指的是Baz。因此,Bar调用超类的方法,但self.__class__仍然引用Bar,而不是Foo。所以你会得到无休止的递归。

这就是为什么你必须总是在super中引用。像这样:

return super(Foo, self)...

修改是的,因为只有顶级类定义了get_defining_fields

你真的是错误的做法。这根本不是super的工作。我认为你可以通过self.__class__.__mro__迭代更好的结果,这是访问超类的方法解析顺序的方法:

class Foo(object):
    def _defining_fields(self):
        return ['in Foo']

    def get_defining_fields(self):
        fields = []
        for cls in self.__class__.__mro__:
            if hasattr(cls, '_defining_fields'):
                fields.append(cls._defining_fields(self))
        return fields

¬