如何检查子类中方法的输出具有正确的大小

时间:2017-06-13 22:50:55

标签: python

是否有一种pythonic方法可以检查父类中子类中方法的输出。

例如,如果你有这个结构,我想对子类中'generate'方法的输出执行检查(在父类中)(例如,如果它具有正确的形状)。

class parent_class(object):

    def generate(self):
        raise NotImplementedError

class child_class(parent_class):

    def generate(self, array_size):
        return np.random.uniform(size = [10,10])

以下实现了正确的效果,但需要在子类的 init 方法中调用方法check_class。有没有办法实现这个检查,而不必记得在每个子类中调用'check_class'方法?

class parent_class(object):

    def generate(self, size):
        raise NotImplementedError

    def check_class(self):
        assert self.generate([5,5]).shape == (5,5), 'Output of generate has the wrong shape'

class child_class(parent_class):

    def __init__(self):
        self.check_class()

    def generate(self, size):
        return np.random.uniform(size = [10,10])

如果您调用generate,现在将根据需要检查输出的大小是否正确:

a = child_class()

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call    last)
<ipython-input-58-389a5f325aca> in <module>()
----> 1 a = child_class()

<ipython-input-56-99054bf26f0e> in __init__(self)
 10 
 11     def __init__(self):
---> 12         self.check_class()
 13 
 14     def generate(self, size):

<ipython-input-56-99054bf26f0e> in check_class(self)
  5 
  6     def check_class(self):
----> 7         assert self.generate([5,5]).shape == (5,5), 'Output of generate has the wrong shape'
  8 
  9 class child_class(parent_class):

AssertionError: Output of generate has the wrong shape

1 个答案:

答案 0 :(得分:0)

您可以在实例创建级别(__new__())设置检查,以便始终执行,例如:

class Parent(object):

    def __new__(cls, *args, **kwargs):
        instance = super(Parent, cls).__new__(cls)  # create our instance
        instance.check()  # immediately call check on it
        return instance  # return it to the requestor

    def check(self):
        assert self.generate() == 10, "generate() must return 10"

    def generate(self):
        raise NotImplementedError

class Child(Parent):

    def generate(self):
        return 10

class BadChild(Parent):

    def generate(self):
        return 20


child = Child()
# everything's fine...

parent = Parent()
# NotImplementedError

bad_child = BadChild()
# AssertionError: generate() must return 10

您也可以使用元类,但如果您不需要,则不要去那里。

作为旁注,我希望你不要把这个东西作为你的模块的一种API。不要编写Javaesque Python - 将模块的用户视为成人,并明确模块所期望的内容。