Dataclass comparison methods giving unexpected results


I just learnt how dataclasses work and was experimenting with them when I ran into an issue with the comparison logic: the first print prints False, even though I expected it to print True, since the overall of player1 is equal to the overall of player2.
Here is the code:

from dataclasses import dataclass, field


@dataclass(order=True)
class Player:
    overall: int = field(init=False)
    name: str
    player_class: str
    health: int
    damage: int

    def __post_init__(self):
        self.overall = self.health + self.damage


player1 = Player('Player1', 'Mage', 200, 400)
player2 = Player('Player2', 'Ranger', 300, 300)

print(player1 >= player2)  # prints False, even though it should print True.
print(player1 <= player2)  # prints True, as expected.

Why is this happening? According to the answer in this very similar post, the fields of player1 and player2 are compared as tuples, where order matters, and the order is the order in which the fields are defined in the dataclass. Since overall is equal for both, the comparison moves on to the next field, name, and hence evaluates to False. But shouldn't it return True as soon as both overall values are found to be equal?



Jasmijn (best answer)

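That's how tuple comparison works in Python: the elements are compared pairwise from left to right, and the first pair that differs decides the result. An equal first element does not settle anything; it just means Python moves on to the second element, and so on. So player1 >= player2 is False because 'Player1' < 'Player2'.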
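A short sketch that prints the tuples actually being compared, using the Player class from the question. It uses dataclasses.astuple as a stand-in for the tuple that the order=True methods build internally; the two only line up here because every field takes part in the comparison (no compare=False anywhere).

```python
from dataclasses import astuple, dataclass, field


@dataclass(order=True)
class Player:
    overall: int = field(init=False)
    name: str
    player_class: str
    health: int
    damage: int

    def __post_init__(self):
        self.overall = self.health + self.damage


player1 = Player('Player1', 'Mage', 200, 400)
player2 = Player('Player2', 'Ranger', 300, 300)

# Fields come out in definition order, overall first.
t1 = astuple(player1)
t2 = astuple(player2)
print(t1)  # (600, 'Player1', 'Mage', 200, 400)
print(t2)  # (600, 'Player2', 'Ranger', 300, 300)

# overall ties at 600, so index 1 (name) decides: 'Player1' < 'Player2'.
print(t1 >= t2)            # False
print(player1 >= player2)  # False, same answer as the plain tuples
```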

If you want <= and >= to be based only on the value of overall, you could write:

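from dataclasses import dataclass, field

@dataclass(order=True)
class Player:
    # Only overall takes part in the generated ==, <, <=, > and >=.
    overall: int = field(init=False)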
    name: str = field(compare=False)
    player_class: str = field(compare=False)
    health: int = field(compare=False)
    damage: int = field(compare=False)

    def __post_init__(self):
        self.overall = self.health + self.damage
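A quick check of that version, as a sketch using the two players from the question plus a made-up third one (Player3 is hypothetical, added only for the demo). Note that compare=False also drops those fields from the generated __eq__, so == now looks only at overall too. Because sorted() is stable, the two tied players keep their input order.

```python
from dataclasses import dataclass, field


@dataclass(order=True)
class Player:
    overall: int = field(init=False)
    name: str = field(compare=False)
    player_class: str = field(compare=False)
    health: int = field(compare=False)
    damage: int = field(compare=False)

    def __post_init__(self):
        self.overall = self.health + self.damage


player1 = Player('Player1', 'Mage', 200, 400)
player2 = Player('Player2', 'Ranger', 300, 300)
player3 = Player('Player3', 'Rogue', 100, 100)  # hypothetical, overall = 200

print(player1 >= player2)  # True: both overall values are 600
print(player1 <= player2)  # True
print(player1 == player2)  # True: name etc. are no longer compared

# Lowest overall first; the 600/600 tie keeps the input order.
print([p.name for p in sorted([player1, player2, player3])])
# ['Player3', 'Player1', 'Player2']
```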
Muhammad Nasir

Try this; it also works. It writes the comparison methods by hand so that they look only at overall:

from dataclasses import dataclass, field

@dataclass
class Player:
    name: str
    player_class: str
    health: int
    damage: int
    overall: int = field(init=False)

    def __post_init__(self):
        self.overall = self.health + self.damage

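    # Hand-written comparisons based only on overall. @dataclass keeps a
    # user-defined __eq__ instead of generating its own. Returning
    # NotImplemented lets Python handle comparisons with non-Player objects.
    def __eq__(self, other):
        if not isinstance(other, Player):
            return NotImplemented
        return self.overall == other.overall

    def __lt__(self, other):
        if not isinstance(other, Player):
            return NotImplemented
        return self.overall < other.overall

    def __le__(self, other):
        if not isinstance(other, Player):
            return NotImplemented
        return self.overall <= other.overall

player1 = Player('Player1', 'Mage', 200, 400)
player2 = Player('Player2', 'Ranger', 300, 300)

print(player1 == player2)  # True
print(player1 <= player2)  # True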
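The hand-written version leaves __gt__ and __ge__ undefined; player1 >= player2 still happens to work because Python falls back to the reflected player2.__le__(player1). A more explicit variant, sketched here as an alternative rather than the answer's own code, uses functools.total_ordering to derive the missing ordering methods from __eq__ and __lt__:

```python
from dataclasses import dataclass, field
from functools import total_ordering


@total_ordering  # fills in __le__, __gt__ and __ge__ from __lt__ and __eq__
@dataclass
class Player:
    name: str
    player_class: str
    health: int
    damage: int
    overall: int = field(init=False)

    def __post_init__(self):
        self.overall = self.health + self.damage

    def __eq__(self, other):
        if not isinstance(other, Player):
            return NotImplemented
        return self.overall == other.overall

    def __lt__(self, other):
        if not isinstance(other, Player):
            return NotImplemented
        return self.overall < other.overall


player1 = Player('Player1', 'Mage', 200, 400)
player2 = Player('Player2', 'Ranger', 300, 300)

print(player1 >= player2)    # True
print(player1 > player2)     # False
print(player1 == 'Player1')  # False: both sides return NotImplemented
```

With total_ordering you only write one ordering method, so the four comparisons cannot drift out of sync.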