Python mypy typing inference involving unions and lists


Hi all, I'm a bit confused about mypy's behavior involving unions and lists. Here is a simplified version to help explain:

from typing import Union


class A:
    pass

class B:
    pass


def f(items, a: Union[type[A], type[B]]) -> Union[list[A], list[B]]:
    return [a() for x in items] # Incompatible return value type (got "list[Union[A, B]]", expected "Union[list[A], list[B]]")

I expect the list comprehension to create a homogeneous list of either A or B. Mypy infers that the list is heterogeneous and can contain items of both type A and type B at the same time. I don't see why it should, since the parameter a is either A or B and does not change during the list creation.

Technically you don't even need the list comprehension (though in our use case we process a list of items using one of several classes). The following also produces the error:

def f(a: Union[type[A], type[B]]) -> Union[list[A], list[B]]:
    return [a()]

Is my understanding of the situation incorrect? Is there a better way to write the function so that it has the intended output signature? (I do indeed want one of several homogeneous lists from this function.)

chepner On BEST ANSWER

The type checker doesn't know which of A or B is passed, so it cannot infer a type for a() more specific than A|B. As a result, the list has type list[A|B], which is not the same as list[A] | list[B]: the former also admits lists that mix A and B instances.
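To make the difference between the two types concrete, here is a minimal sketch. The variable names and the is_homogeneous helper are hypothetical and only for illustration; the helper checks at runtime what the annotation Union[list[A], list[B]] is meant to promise statically.

```python
from typing import Union


class A:
    pass


class B:
    pass


# Fits list[Union[A, B]]: the elements may be a mix of A and B.
mixed: list[Union[A, B]] = [A(), B()]

# Fits Union[list[A], list[B]]: every element has the same class.
only_a: Union[list[A], list[B]] = [A(), A()]


def is_homogeneous(values: list) -> bool:
    """Hypothetical runtime helper: True if all elements share one exact class."""
    return len({type(v) for v in values}) <= 1
```

A value like mixed is acceptable for the first annotation but for neither member of the second, which is why the checker refuses to treat list[A|B] as list[A] | list[B].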

You should define a generic function with a constrained type variable instead.

from typing import TypeVar


T = TypeVar('T', A, B)


def f(items, t: type[T]) -> list[T]:
    return [t() for x in items]

From the perspective of the type checker, you have now defined two functions, one of type Callable[[Any, Type[A]], list[A]] and one of type Callable[[Any, Type[B]], list[B]], rather than a single function of type Callable[[Any, Type[A]|Type[B]], list[A]|list[B]]. The type of the argument passed for t determines which of the two signatures applies to a given call, and therefore which return type is inferred:

f(something, A)  # call the Callable[[Any, Type[A]], list[A]] version
f(something, B)  # call the Callable[[Any, Type[B]], list[B]] version
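For reference, the pieces above can be combined into one self-contained sketch that also runs as a plain script. The Iterable[object] annotation on items and the names as_list and bs_list are assumptions added here for illustration; the original leaves items unannotated.

```python
from typing import Iterable, TypeVar


class A:
    pass


class B:
    pass


# Constrained type variable: T is solved to exactly A or exactly B.
T = TypeVar("T", A, B)


def f(items: Iterable[object], t: type[T]) -> list[T]:
    # Build one instance of t per element of items; the elements are unused.
    return [t() for _ in items]


as_list = f(range(3), A)  # expected to be inferred as list[A]
bs_list = f("xy", B)      # expected to be inferred as list[B]
```

At runtime both calls execute the same function; the distinction between the two signatures exists only for the type checker.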