Narrow typing for short circuit `or`, when I have two possible cases


I'm trying to annotate the following code.

The function is designed to work either when both zone and zones are defined, or when file is defined (but not both):

import pathlib

def get_file(zone: str, zones: dict[str, str]) -> pathlib.Path:
    raise NotImplementedError  # implementation omitted

def connect(
        zone: str | None = None,
        zones: dict[str, str] | None = None,
        file: pathlib.Path | None = None,
) -> bool:
    file = file or get_file(zone, zones)

But it makes mypy angry:

1. Argument of type "str | None" cannot be assigned to parameter "zone" of type "str" in function "_get_vpn_file"
     Type "str | None" cannot be assigned to type "str"
       Type "None" cannot be assigned to type "str"
2. Argument of type "dict[str, str] | None" cannot be assigned to parameter "zones" of type "dict[str, str]" in function "_get_vpn_file"
     Type "dict[str, str] | None" cannot be assigned to type "dict[str, str]"
       Type "None" cannot be assigned to type "dict[str, str]"

Then I tried some aggressive type narrowing:

def _check_params_are_ok(
    zone: str | None, zones: dict[str, str] | None, file: pathlib.Path | None,
) -> tuple[str, dict[str, str], None] | tuple[None, None, pathlib.Path]:
    if zone is not None and file is not None:
        raise ValueError("Pass `file` or `zone`, but not both.")

    if zone is not None and zones is None:
        raise ValueError("connect: Must define `zones` when `zone` is defined.")

    if zone is None and file is None:
        raise ValueError("connect: Must define `zone` or `file`.")

    assert file is not None or (zone is not None and zones is not None)

    # Type narrowing
    if zone is not None and zones is not None and file is None:
        return zone, zones, file
    if zone is None and zones is None and file is not None:
        return zone, zones, file

    raise NotImplementedError("This error from _check_params_ok shouldn't happen.")


def connect(
        zone: str | None = None,
        zones: dict[str, str] | None = None,
        file: pathlib.Path | None = None,
) -> bool:
    zone, zones, file = _check_params_are_ok(zone, zones, file)
    file = file or get_file(zone, zones)

Mypy still shows the same errors, even when I add very explicit assertions:

    zone, zones, file = _check_params_are_ok(zone, zones, file)
    if file is None:
        assert zone is not None and zones is not None
    file = file or get_file(zone, zones)

The best solution I have found so far is to cast the types inline, but it hurts readability and makes the line hard to follow:

    file = file or get_file(cast(str, zone), cast(dict[str, str], zones))

Is there any good way to narrow the types?

Answer by Stefan Marinov

In the version with the (as you put it) very clear assertions, move the assignment to file into the body of the if statement:

if file is None:
    assert zone is not None and zones is not None
    file = get_file(zone, zones)

# file is now a pathlib.Path object
assert file.is_file()  # this is now valid

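The call to get_file now sits in the same branch as the assertion, so the narrowed types of zone and zones are still in effect when it runs. In the original, the narrowing was lost once the if block ended, which is why `file or get_file(zone, zones)` was still rejected. Written this way, the code is very explicit and mypy shouldn't complain about it anymore.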
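If you would rather not rely on assert (assertions are skipped when Python runs with -O), the same narrowing also works with an explicit check that raises. The following is only a minimal sketch under that assumption: the helper name resolve_file and the body of get_file are placeholders invented for illustration, not part of the original code.

```python
from __future__ import annotations

import pathlib


def get_file(zone: str, zones: dict[str, str]) -> pathlib.Path:
    # Placeholder body (an assumption for this sketch): look the zone up in the mapping.
    return pathlib.Path(zones[zone])


def resolve_file(
    zone: str | None = None,
    zones: dict[str, str] | None = None,
    file: pathlib.Path | None = None,
) -> pathlib.Path:
    # Hypothetical helper, named only for this example.
    if file is not None:
        if zone is not None:
            raise ValueError("Pass `file` or `zone`, but not both.")
        return file
    if zone is None or zones is None:
        # Raising here lets the type checker rule out None for both names below.
        raise ValueError("Must define `file`, or both `zone` and `zones`.")
    # At this point zone is narrowed to str and zones to dict[str, str].
    return get_file(zone, zones)
```

Because every invalid combination raises before get_file is reached, no cast or assert is needed at the call.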

Answer by Daniil Fajnberg

The answer by Stefan Marinov is correct. Assertions are useful for these situations.

For the sake of completeness, I would add that in constructs like this, where only specific combinations of call arguments are supposed to occur, it is advisable to declare those variants with typing.overload:

from pathlib import Path
from typing import overload

def get_file(zone: str, zones: dict[str, str]) -> Path:
    raise NotImplementedError  # implementation omitted

@overload
def connect(zone: str, zones: dict[str, str], file: None = None) -> bool:
    ...

@overload
def connect(zone: None, zones: None, file: Path) -> bool:
    ...

def connect(
    zone: str | None = None,
    zones: dict[str, str] | None = None,
    file: Path | None = None,
) -> bool:
    if file is None:
        assert zone is not None and zones is not None
        file = get_file(zone, zones)
    raise NotImplementedError  # rest of the implementation omitted

This allows type checkers (and by extension your IDE) to alert you to incorrect usage of connect. For example, with the overload signatures defined as above, both of these calls cause mypy errors:

connect(None, None, None)
connect("abc", {"x": "y"}, Path("."))
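Keep in mind that overloads are checked statically only; at runtime just the final implementation exists. The sketch below illustrates this with the two calls above. The bodies of get_file and connect are placeholders assumed for the example, and call_is_rejected_at_runtime is a hypothetical helper, not part of the original code.

```python
from __future__ import annotations

from pathlib import Path
from typing import overload


def get_file(zone: str, zones: dict[str, str]) -> Path:
    # Placeholder body, assumed for this sketch only.
    return Path(zones[zone])


@overload
def connect(zone: str, zones: dict[str, str], file: None = None) -> bool: ...
@overload
def connect(zone: None, zones: None, file: Path) -> bool: ...
def connect(
    zone: str | None = None,
    zones: dict[str, str] | None = None,
    file: Path | None = None,
) -> bool:
    if file is None:
        assert zone is not None and zones is not None
        file = get_file(zone, zones)
    # Placeholder result: the real function would connect using `file`.
    return True


def call_is_rejected_at_runtime(*args: object) -> bool:
    # Hypothetical helper: True if the implementation's own assertion fires.
    try:
        connect(*args)  # type: ignore
    except AssertionError:
        return True
    return False
```

Under these assumptions, connect(None, None, None) is stopped at runtime only by the assertion inside the implementation, while connect("abc", {"x": "y"}, Path(".")) runs without complaint. Only the type checker flags the second call, so any runtime validation you need still has to live in the implementation.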

That said, I would suggest reconsidering whether this design is truly what you want. Calling a function as connect(None, None, Path(...)) seems very unnatural to me. If this is indeed how the function is supposed to be called, I would at least make all arguments keyword-only, which allows more explicit and readable calls:

from pathlib import Path
from typing import overload

def get_file(zone: str, zones: dict[str, str]) -> Path:
    raise NotImplementedError  # implementation omitted

@overload
def connect(*, zone: str, zones: dict[str, str]) -> bool:
    ...

@overload
def connect(*, file: Path) -> bool:
    ...

def connect(
    *,
    zone: str | None = None,
    zones: dict[str, str] | None = None,
    file: Path | None = None,
) -> bool:
    if file is None:
        assert zone is not None and zones is not None
        file = get_file(zone, zones)
    raise NotImplementedError  # rest of the implementation omitted


connect(file=Path("."))
connect(zone="abc", zones={"x": "y"})

But this may just be subjective. You know best what you need for your purposes.

EDIT: Just realized that the implementation has None defaults for all arguments, so calling it as I suggested was already possible. Still, I would argue that enforcing keyword-only would be more explicit and "explicit is better than implicit".
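To make the keyword-only point concrete, here is a small sketch of what the bare `*` enforces at runtime, independently of any type checker. The body of connect is a placeholder assumed for this example, and positional_call_fails is a hypothetical helper.

```python
from __future__ import annotations

from pathlib import Path


def connect(
    *,
    zone: str | None = None,
    zones: dict[str, str] | None = None,
    file: Path | None = None,
) -> bool:
    # Placeholder body for this sketch: report whether a valid combination was given.
    return file is not None or (zone is not None and zones is not None)


def positional_call_fails() -> bool:
    # Hypothetical helper: passing arguments positionally raises TypeError
    # because every parameter after the bare `*` is keyword-only.
    try:
        connect("abc", {"x": "y"})  # type: ignore
    except TypeError:
        return True
    return False
```

With this signature, connect(file=Path(".")) and connect(zone="abc", zones={"x": "y"}) are accepted, while the positional form is rejected by Python itself.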