Initialize subclass instance when creating new Base class instance


Summary

TLDR: Is it possible to make calling (instantiating) the base class return an initialized instance of one of its subclasses?

Example

Consider this Animal base class and Cat and Dog subclasses:

from abc import ABC, abstractmethod

class Animal(ABC):

    @property
    @abstractmethod
    def weight(self) -> float:
        """weight of the animal in kg."""
        ...


class Dog(Animal):
    def __init__(self, weight: float = 5):
        if not (1 < weight < 90):
            raise ValueError("No dog has this weight")
        self._weight = weight

    weight: float = property(lambda self: self._weight)


class Cat(Animal):
    def __init__(self, weight: float = 5):
        if not (0.5 < weight < 15):
            raise ValueError("No cat has this weight")
        self._weight = weight

    weight: float = property(lambda self: self._weight)

This works as intended:

c1 = Cat(0.7)  # no problem
c2 = Cat(30)  # ValueError

Now I want to extend this so that calling the Animal class returns an instance of one of its subclasses, namely the first one whose __init__ does not raise an error.

So, I want c3 = Animal(0.7) to return a Cat instance.

Attempt

I know how to return a subclass instance when instantiating the base class, but only if the subclass can be determined before __init__ runs.

So, this does not work...

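class Animal(ABC):

    def __new__(cls, *args, **kwargs):
        if cls in Animal.__subclasses__():  # a subclass was called directly
            return object.__new__(cls)

        for subcls in [Dog, Cat]:  # prefer to return a dog
            try:
                # object.__new__ only allocates; it never runs __init__,
                # so the ValueError cannot be raised (or caught) here
                return object.__new__(subcls)
            except ValueError:
                pass

    @property
    @abstractmethod
    def weight(self) -> float:
        """weight of the animal in kg."""
        ...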

...because the ValueError is only raised in __init__, which Python calls after __new__ has already created and returned the instance:

c3 = Animal(0.7)  # ValueError ('No dog has this weight') instead of a Cat instance
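The order in which Python runs the two hooks explains this. The sketch below uses a hypothetical helper, simulate_call, as a simplified model of what calling a class does (the real logic lives in type.__call__ and handles more cases), together with a toy class Picky that is also only for illustration. __new__ just allocates the object; the validation in __init__ runs afterwards, and only if __new__ returned an instance of the class that was called.

```python
events = []  # records the order in which the hooks run


class Picky:
    def __new__(cls, value):
        events.append("__new__")
        # Allocating the bare object never looks at `value`, so nothing can fail here.
        return super().__new__(cls)

    def __init__(self, value):
        events.append("__init__")
        if value < 0:
            raise ValueError("negative value")
        self.value = value


def simulate_call(cls, *args, **kwargs):
    """Hypothetical, simplified model of what cls(*args, **kwargs) does."""
    instance = cls.__new__(cls, *args, **kwargs)
    # __init__ only runs if __new__ returned an instance of cls (or of a subclass).
    if isinstance(instance, cls):
        type(instance).__init__(instance, *args, **kwargs)
    return instance


try:
    Picky(-1)
except ValueError as exc:
    print(events, exc)  # ['__new__', '__init__'] negative value
```

So any try/except placed around the allocation in __new__ is already finished by the time __init__ raises.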

Is there a way to achieve this?

Current workaround

This works, but it is detached from the classes and feels poorly integrated and tightly coupled.

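def create_appropriate_animal(*args, **kwargs) -> Animal:
    for cls in [Dog, Cat]:  # list order = preference order
        try:
            return cls(*args, **kwargs)
        except ValueError:
            pass  # this animal does not fit; try the next one
    raise ValueError("No fitting animal found.")  # raising a plain string is a TypeError in Python 3

c4 = create_appropriate_animal(0.7)  # returns a Cat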
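One way to keep the same trial-construction logic but attach it to the class is a classmethod on Animal that iterates over Animal.__subclasses__() instead of a hard-coded list. This is only a sketch, not from the original post: the method name create is hypothetical, it only sees direct subclasses (in definition order), and it is meant to be called on Animal itself.

```python
from abc import ABC, abstractmethod


class Animal(ABC):
    @classmethod
    def create(cls, *args, **kwargs) -> "Animal":
        """Hypothetical factory method: return the first direct subclass of `cls`
        (in definition order) whose constructor accepts the arguments."""
        for subcls in cls.__subclasses__():
            try:
                return subcls(*args, **kwargs)
            except ValueError:
                pass  # this subclass rejects the arguments; try the next one
        raise ValueError("No fitting animal found.")

    @property
    @abstractmethod
    def weight(self) -> float:
        """weight of the animal in kg."""


class Dog(Animal):
    def __init__(self, weight: float = 5):
        if not (1 < weight < 90):
            raise ValueError("No dog has this weight")
        self._weight = weight

    weight = property(lambda self: self._weight)


class Cat(Animal):
    def __init__(self, weight: float = 5):
        if not (0.5 < weight < 15):
            raise ValueError("No cat has this weight")
        self._weight = weight

    weight = property(lambda self: self._weight)


c4 = Animal.create(0.7)
print(type(c4).__name__)  # Cat
```

Calling Animal(...) directly still fails here because Animal is abstract; the factory has to be requested explicitly, which keeps the normal construction path unchanged.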

EDIT:

  • Thanks @chepner for the __subclasses__() suggestion; I've integrated it into the question.

Answer by Serge Ballesta

This is indeed an unusual requirement, but it can be met by customizing __new__:

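class Animal(ABC):

    @property
    @abstractmethod
    def weight(self) -> float:
        """weight of the animal in kg."""
        ...

    def __new__(cls, *args, **kwargs):
        if cls == Animal:
            # try the direct subclasses in definition order; the first one whose
            # __init__ does not raise wins (note: the loop reuses the name `cls`)
            for cls in Animal.__subclasses__():
                try:
                    return cls(*args, **kwargs)
                except ValueError:  # Dog and Cat signal a wrong weight with ValueError
                    pass
        return super().__new__(cls)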

You can now successfully write:

a = Animal(2)   # a will be Dog(2)
b = Animal(0.7) # b will be Cat(0.7)
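One side effect worth knowing: when Animal.__new__ returns an object that is an instance of Animal, Python calls that object's __init__ again with the same arguments. The sketch below is a trimmed copy of this approach that catches ValueError, raises explicitly when no subclass fits, and adds a hypothetical init_calls log to make every __init__ call visible.

```python
from abc import ABC, abstractmethod

init_calls = []  # (class name, weight) for every __init__ call


class Animal(ABC):
    @property
    @abstractmethod
    def weight(self) -> float:
        """weight of the animal in kg."""

    def __new__(cls, *args, **kwargs):
        if cls is Animal:
            for subcls in Animal.__subclasses__():
                try:
                    # Builds the subclass completely, i.e. its __init__ runs here...
                    return subcls(*args, **kwargs)
                except ValueError:
                    pass
            raise ValueError("No fitting animal found.")
        return super().__new__(cls)


class Dog(Animal):
    def __init__(self, weight: float = 5):
        init_calls.append(("Dog", weight))
        if not (1 < weight < 90):
            raise ValueError("No dog has this weight")
        self._weight = weight

    weight = property(lambda self: self._weight)


class Cat(Animal):
    def __init__(self, weight: float = 5):
        init_calls.append(("Cat", weight))
        if not (0.5 < weight < 15):
            raise ValueError("No cat has this weight")
        self._weight = weight

    weight = property(lambda self: self._weight)


b = Animal(0.7)
# ...and because the returned object is an Animal, Python then calls Cat.__init__ again.
print(init_calls)  # [('Dog', 0.7), ('Cat', 0.7), ('Cat', 0.7)]
```

For these classes the second call is harmless because __init__ only validates and stores the weight, but an __init__ with side effects (I/O, counters, registration) would run twice.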

BTW, if every subclass raises a ValueError, the last one tried is used (and raises its own error).
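A variation that is not part of this answer avoids constructing candidate objects by trial and error: each subclass exposes a cheap check that __new__ can call before any __init__ runs, which is the "subclass can be determined before __init__" case the question already knows how to handle. The classmethod name accepts is a hypothetical hook introduced for this sketch; it assumes the validation can be expressed without building the object.

```python
from abc import ABC, abstractmethod


class Animal(ABC):
    def __new__(cls, *args, **kwargs):
        if cls is not Animal:
            # A concrete subclass was called directly: create it as usual.
            return super().__new__(cls)
        # Animal(...) was called: choose a subclass BEFORE any __init__ runs,
        # using the hypothetical accepts() check each subclass provides.
        for subcls in Animal.__subclasses__():  # direct subclasses, definition order
            if subcls.accepts(*args, **kwargs):
                # Python calls subcls.__init__ on this object afterwards,
                # because it is an instance of Animal.
                return super().__new__(subcls)
        raise ValueError("No fitting animal found.")

    @classmethod
    @abstractmethod
    def accepts(cls, weight: float = 5) -> bool:
        """Hypothetical hook: can this class be built with this weight?"""

    @property
    @abstractmethod
    def weight(self) -> float:
        """weight of the animal in kg."""


class Dog(Animal):
    @classmethod
    def accepts(cls, weight: float = 5) -> bool:
        return 1 < weight < 90

    def __init__(self, weight: float = 5):
        if not self.accepts(weight):
            raise ValueError("No dog has this weight")
        self._weight = weight

    weight = property(lambda self: self._weight)


class Cat(Animal):
    @classmethod
    def accepts(cls, weight: float = 5) -> bool:
        return 0.5 < weight < 15

    def __init__(self, weight: float = 5):
        if not self.accepts(weight):
            raise ValueError("No cat has this weight")
        self._weight = weight

    weight = property(lambda self: self._weight)


c3 = Animal(0.7)
print(type(c3).__name__)  # Cat
```

Here __init__ runs exactly once on the chosen object, and the validation rule lives in one place (accepts) that both __new__ and __init__ use.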

Answer by matszwecja

@ElRudi: As requested, here is a solution using the Factory pattern. It would probably be possible to detect all classes inheriting from Animal automatically, so that the list of possible animals would not have to be passed in when creating the factory, but I don't think that's a good idea, so I didn't bother doing it.

class AnimalFactory:
    def __init__(self, animalTypes: list[type[Animal]] | None = None):
        # the order of animalTypes is the order in which they are tried
        self.animals = list(animalTypes) if animalTypes is not None else []

    def createAnimal(self, weight: float) -> Animal:
        for AnimalType in self.animals:
            try:
                return AnimalType(weight)
            except ValueError:
                pass  # weight out of range for this type; try the next one
        raise ValueError("No animal has this weight")


af = AnimalFactory([Dog, Cat])

c1 = af.createAnimal(0.7)  # Cat
print(c1)
c2 = af.createAnimal(30)  # Dog
print(c2)
c3 = af.createAnimal(100) # ValueError
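Because the factory takes an explicit list, the list order also acts as the preference order, which Animal.__subclasses__() (fixed to definition order) would not let you choose per call site. A self-contained sketch with trimmed copies of Dog, Cat and AnimalFactory (the ABC base is left out for brevity; the names dog_first and cat_first are only for illustration):

```python
class Dog:
    def __init__(self, weight: float = 5):
        if not (1 < weight < 90):
            raise ValueError("No dog has this weight")
        self.weight = weight


class Cat:
    def __init__(self, weight: float = 5):
        if not (0.5 < weight < 15):
            raise ValueError("No cat has this weight")
        self.weight = weight


class AnimalFactory:
    def __init__(self, animalTypes=None):
        # the order of animalTypes is the order in which they are tried
        self.animals = list(animalTypes) if animalTypes is not None else []

    def createAnimal(self, weight: float):
        for AnimalType in self.animals:
            try:
                return AnimalType(weight)
            except ValueError:
                pass  # weight out of range for this type; try the next one
        raise ValueError("No animal has this weight")


# 5 kg fits both a dog and a cat, so the list order decides.
dog_first = AnimalFactory([Dog, Cat])
cat_first = AnimalFactory([Cat, Dog])
print(type(dog_first.createAnimal(5)).__name__)  # Dog
print(type(cat_first.createAnimal(5)).__name__)  # Cat
```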