Type hinting a dictionary with a given set of keys and values


Suppose I have an Enum class in Python:

class MyEnum(Enum):
    A = "a"
    B = "b"

I have a function that returns one value for each of the possible enum members (two in this case), each with a given type; suppose both values are DataFrames. I want to type-hint this function, and for that I was using TypedDict like this:

import pandas as pd
from typing import TypedDict

class ReturnedType(TypedDict):
    MyEnum.A.value: pd.DataFrame
    MyEnum.B.value: pd.DataFrame

and then:

def foo(...) -> ReturnedType: ...

but TypedDict apparently does not accept field names computed from other variables, and this causes the mypy check to fail.

What is the most pythonic way to type hint such a function in this case?
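When every value has the same type, as in the DataFrame case above, a plain mapping keyed by the enum may already be enough. A minimal sketch, assuming per-key types are not needed; `int` stands in for `pd.DataFrame` so it runs without pandas:

```python
from enum import Enum
from typing import Dict


class MyEnum(Enum):
    A = "a"
    B = "b"


# int stands in for pd.DataFrame so the sketch runs without pandas installed.
def foo() -> Dict[MyEnum, int]:
    # One entry per enum member; a type checker sees Dict[MyEnum, int].
    return {MyEnum.A: 1, MyEnum.B: 2}


result = foo()
value: int = result[MyEnum.A]  # the checker knows this is an int
```

The trade-off is that `Dict` neither promises that both keys are present nor allows a different type per key, which is what the MWE below needs.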

Here is a minimal working example (MWE):

from typing import Dict, TypedDict
from enum import Enum

class MyEnum(Enum):
    A = 'a'
    B = 'b'

class MyClass(TypedDict):
    """Defines the shape of the dictionary returned by foo."""

    MyEnum.A.value: int
    MyEnum.B.value: float

def foo() -> MyClass:
    res: MyClass = {MyEnum.A.value: 3, MyEnum.B.value: 4.4}
    return res

Running mypy on this gives the error:

Invalid statement in TypedDict definition; expected "field_name: field_type" [misc]
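The message means every line in a class-based TypedDict body must be a plain identifier annotated with a type, so `MyEnum.A.value: int` is rejected. A minimal sketch of a form the checker accepts, assuming the enum values are valid identifiers that can be repeated as literal keys; the import-time consistency check is a hypothetical addition of mine, not a typing feature:

```python
from enum import Enum
from typing import TypedDict  # in typing since Python 3.8


class MyEnum(Enum):
    A = "a"
    B = "b"


class MyClass(TypedDict):
    # Field names must be plain identifiers, so the enum values are spelled out.
    a: int
    b: float


# Hypothetical runtime guard: raise at import time if the TypedDict keys
# drift out of sync with the enum values.
if set(MyClass.__annotations__) != {member.value for member in MyEnum}:
    raise TypeError("MyClass keys do not match MyEnum values")


def foo() -> MyClass:
    res: MyClass = {"a": 3, "b": 4.4}
    return res
```

If a value is not a valid identifier, the functional syntax `TypedDict("MyClass", {"a": int, "b": float})` accepts arbitrary string keys instead.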

Also note that I am running Python 3.8, so StrEnum (added in Python 3.11) is not available.


Answer by Robin Gugel

Sadly, there's no clean way to achieve this.

There is a cumbersome workaround, though. I haven't tested it with mypy yet, but it works at least with Pylance/Pyright in VS Code:

from enum import Enum
from typing import Dict, Union, overload, Literal
class MyEnum(Enum):
    A = "a"
    B = "b"

class ReturnedType(Dict[MyEnum, Union[int, str]]):

    @overload
    def __getitem__(self, __key: Literal[MyEnum.A]) -> int:
        ...
    
    @overload
    def __getitem__(self, __key: Literal[MyEnum.B]) -> str:
        ...

    def __getitem__(self, __key: MyEnum) -> Union[int, str]:
        # super() objects are not subscriptable, so call dict.__getitem__ directly
        return super().__getitem__(__key)
    
    @overload
    def get(self, __key: Literal[MyEnum.A]) -> Union[int, None]:
        ...

    @overload
    def get(self, __key: Literal[MyEnum.A], __default: int) -> int:
        ...

    @overload
    def get(self, __key: Literal[MyEnum.B]) -> Union[str, None]:
        ...

    @overload
    def get(self, __key: Literal[MyEnum.B], __default: str) -> str:
        ...

    def get(self, __key: MyEnum, __default: Union[int, str, None] = None) -> Union[int, str, None]:
        return super().get(__key, __default)

    @overload
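    def __setitem__(self, __key: Literal[MyEnum.A], __value: int) -> None:
        ...

    @overload
    def __setitem__(self, __key: Literal[MyEnum.B], __value: str) -> None:
        ...

    def __setitem__(self, __key: MyEnum, __value: Union[int, str]) -> None:
        # super() objects do not support item assignment, so call dict.__setitem__ directly
        super().__setitem__(__key, __value)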
    
def foo() -> ReturnedType:
    rv = ReturnedType()
    rv[MyEnum.A] = 3
    rv[MyEnum.B] = "hello"
    return rv

d: int = foo()[MyEnum.A]
a: str = foo()[MyEnum.A]  # flagged by the type checker: int is not compatible with str
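A lighter variant of the same overload idea, sketched here as an untested assumption rather than part of the answer above: keep a plain dict and put the `Literal` overloads on a hypothetical helper function, `get_result`, so no dict methods need overriding. PEP 586 allows enum members inside `Literal`, so a checker that implements it should select the matching overload:

```python
from enum import Enum
from typing import Dict, Literal, Union, overload


class MyEnum(Enum):
    A = "a"
    B = "b"


# Plain dict at runtime; the precise per-key types live in the accessor below.
Results = Dict[MyEnum, Union[int, str]]


@overload
def get_result(results: Results, key: Literal[MyEnum.A]) -> int: ...


@overload
def get_result(results: Results, key: Literal[MyEnum.B]) -> str: ...


def get_result(results: Results, key: MyEnum) -> Union[int, str]:
    # At runtime this is an ordinary dict lookup; the overloads only guide the checker.
    return results[key]


def foo() -> Results:
    return {MyEnum.A: 3, MyEnum.B: "hello"}


d: int = get_result(foo(), MyEnum.A)
s: str = get_result(foo(), MyEnum.B)
```

The cost is that callers must go through `get_result` instead of indexing the dict, and direct indexing still only yields `Union[int, str]`.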