fjhj2, 4 years ago
Python Question

How to check that the output of a method in a child class has the correct size

Is there a Pythonic way of checking the outputs of a method in a child class from the parent class?

For example, given the structure below, I want to perform a check (in the parent class) on the output of the child class's 'generate' method (e.g. whether it has the correct shape).

import numpy as np

class parent_class(object):

    def generate(self, size):
        raise NotImplementedError

class child_class(parent_class):

    def generate(self, size):
        # ignores size, so the output shape may not be what the caller asked for
        return np.random.uniform(size = [10,10])


The following achieves the desired effect, but requires calling check_class in the __init__ method of the child class. Is there any way of performing this check without having to remember to put the call to check_class in every child class?

class parent_class(object):

    def generate(self, size):
        raise NotImplementedError

    def check_class(self):
        assert self.generate([5,5]).shape == (5,5), 'Output of generate has the wrong shape'

class child_class(parent_class):

    def __init__(self):
        self.check_class()

    def generate(self, size):
        # deliberately ignores size, so check_class() fails
        return np.random.uniform(size = [10,10])


If you instantiate the child class, the size of the output of generate is now checked, as required:

a = child_class()

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-58-389a5f325aca> in <module>()
----> 1 a = child_class()

<ipython-input-56-99054bf26f0e> in __init__(self)
     10 
     11     def __init__(self):
---> 12         self.check_class()
     13 
     14     def generate(self, size):

<ipython-input-56-99054bf26f0e> in check_class(self)
      5 
      6     def check_class(self):
----> 7         assert self.generate([5,5]).shape == (5,5), 'Output of generate has the wrong shape'
      8 
      9 class child_class(parent_class):

AssertionError: Output of generate has the wrong shape

Answer

You can run the check at instance-creation time, in __new__(), so that it always executes for every subclass without any extra call, e.g.:

class Parent(object):

    def __new__(cls, *args, **kwargs):
        instance = super(Parent, cls).__new__(cls)  # create our instance
        instance.check()  # immediately call check on it
        return instance  # return it to the requestor

    def check(self):
        assert self.generate() == 10, "generate() must return 10"

    def generate(self):
        raise NotImplementedError

class Child(Parent):

    def generate(self):
        return 10

class BadChild(Parent):

    def generate(self):
        return 20


child = Child()
# everything's fine...

parent = Parent()
# NotImplementedError

bad_child = BadChild()
# AssertionError: generate() must return 10
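Applied to the shape check from the question, the same __new__() pattern might look like the sketch here. It assumes NumPy is installed and that generate() takes the requested shape as its only argument; the class names ShapeCheckedParent, GoodChild and WrongShapeChild and the probe_shape attribute are hypothetical, chosen for illustration.

```python
import numpy as np


class ShapeCheckedParent(object):
    # Hypothetical probe shape, mirroring the (5, 5) used in the question.
    probe_shape = (5, 5)

    def __new__(cls, *args, **kwargs):
        instance = super(ShapeCheckedParent, cls).__new__(cls)  # create the instance
        # Probe generate() once at creation time, before any __init__ runs.
        result = instance.generate(cls.probe_shape)
        if np.shape(result) != cls.probe_shape:
            raise ValueError(
                "%s.generate() returned shape %s, expected %s"
                % (cls.__name__, np.shape(result), cls.probe_shape)
            )
        return instance

    def generate(self, size):
        raise NotImplementedError


class GoodChild(ShapeCheckedParent):
    def generate(self, size):
        return np.random.uniform(size=size)  # honours the requested shape


class WrongShapeChild(ShapeCheckedParent):
    def generate(self, size):
        return np.random.uniform(size=[10, 10])  # ignores size on purpose
```

GoodChild() succeeds, WrongShapeChild() raises ValueError, and ShapeCheckedParent() raises NotImplementedError. Raising ValueError rather than using assert keeps the check active when Python runs with -O, which strips assert statements. Keep in mind that __new__() runs before __init__(), so generate() cannot rely on attributes that a child's __init__() sets.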

You could also use a metaclass, but don't go there unless you need to.
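For completeness, a minimal metaclass sketch: overriding the metaclass's __call__ lets the check run after __init__() has finished, so generate() may rely on attributes set there. The names CheckedMeta, Base, ConfiguredChild and the value parameter are hypothetical, and the syntax is Python 3 only.

```python
class CheckedMeta(type):
    """Metaclass that validates every instance right after construction."""

    def __call__(cls, *args, **kwargs):
        # type.__call__ runs __new__ and then __init__ as usual.
        instance = super().__call__(*args, **kwargs)
        # The instance is fully initialised here, so check() can use its attributes.
        instance.check()
        return instance


class Base(metaclass=CheckedMeta):
    def check(self):
        if self.generate() != 10:
            raise ValueError("generate() must return 10")

    def generate(self):
        raise NotImplementedError


class ConfiguredChild(Base):
    def __init__(self, value=10):
        self.value = value  # set in __init__, still visible to check()

    def generate(self):
        return self.value
```

ConfiguredChild() passes, ConfiguredChild(20) raises ValueError, and Base() raises NotImplementedError.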

As a side note, I hope you're not doing this as a kind of API for your module, though. Don't write Java-esque Python: treat the users of your module as adults and make it clear what your module expects.
