Chess game optimization: inefficient performance despite attack tiles update


I am developing a chess game and recently changed my code to make move generation faster by updating attacked tiles and pins incrementally. Previously, I calculated the enemy side's attacked tiles every time I generated moves. Now I calculate this information once, when the Board object is initialized, and then update only the relevant squares after each move.

Here are the update function and the relevant data structures:

private HashMap<Integer, List<List<List<Integer>>>> attackedTiles;
private HashMap<Integer, List<List<Integer>>> pins;

private void updateAttackedTilesAndPins(int oldIndex, int newIndex) {
    int[] colors = new int[]{Piece.WHITE, Piece.BLACK};
    int pieceIndex = Piece.index(tile[newIndex]);

    for (int color : colors) {
        List<List<List<Integer>>> attackedTilesColor = attackedTiles.get(color);

        for (int i = 0; i < attackedTilesColor.size(); i++) {
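            //no point in looking at the moved piece, since we will recalculate its attacked tiles anyway;
            //only skip it for the side that moved, because index i refers to a different piece in the other color's list
            if (color == turn && i == pieceIndex) {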
                continue;
            }

            //update every line of sight that could see the moved piece
            for (List<Integer> lineOfSight : attackedTilesColor.get(i)) {
                if (lineOfSight.contains(oldIndex) || lineOfSight.contains(newIndex)) {
                    updateLineOfSight(lineOfSight);
                }
            }
        }
    }

    //calculate the attacked tiles for the moved piece
    attackedTiles.get(turn).set(pieceIndex, calculateAttackedTiles(newIndex));

    pins.get(turn).clear();

    //check if any piece is pinning anything
    List<Integer> piecePositionsTurn = piecePositions.get(turn);
    for (int piecePosition : piecePositionsTurn) {
        if (piecePosition == -1) {
            continue;
        }

        List<Integer> pinLine = calculatePinLine(piecePosition);
        if (!pinLine.isEmpty()) {
            pins.get(turn).add(pinLine);
        }
    }
}

I know attackedTiles is ugly, so here is an example of its contents. Each row belongs to one piece, and each inner list is one line of sight in one direction: the piece's own square followed by the squares it sees in that direction:

Attacked tiles white:
[[48, 41]]
[[58, 49], [58, 51]]
[[59, 50], [59, 51], [59, 52], [59, 58], [59, 60]]
[[60, 51], [60, 52], [60, 53], [60, 59], [60, 61]]
[[61, 52], [61, 54]]
[[62, 45], [62, 47], [62, 52]]
[[63, 55], [63, 62]]
Attacked tiles black:
[[0, 1], [0, 8]]
[[1, 11], [1, 16], [1, 18]]
[[2, 9], [2, 11]]
[[3, 2], [3, 4], [3, 10], [3, 11], [3, 12]]
[[4, 3], [4, 5], [4, 11], [4, 12], [4, 13]]
[[5, 12], [5, 14]]
[[6, 12], [6, 21], [6, 23]]
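Since squares are numbered 0 to 63, one common alternative (a bitboard-style encoding, swapped in here; it is not part of the original code) is to store each line of sight as a single 64-bit long with one bit per square. The containment test in the update loop then becomes one bitwise AND instead of two list scans. A minimal sketch, assuming the 0-63 square indices shown above; the class and method names are hypothetical:

```java
import java.util.List;

// Hypothetical helper: encodes a line of sight (origin square first, then the
// squares it reaches) as a 64-bit mask with one bit per square 0..63.
final class LineOfSightMask {
    private LineOfSightMask() {}

    // Builds a mask from square indices such as [59, 50].
    static long toMask(List<Integer> squares) {
        long mask = 0L;
        for (int square : squares) {
            mask |= 1L << square;
        }
        return mask;
    }

    // Bitmask version of lineOfSight.contains(oldIndex) || lineOfSight.contains(newIndex):
    // a single AND against a mask holding both squares touched by the move.
    static boolean touches(long lineMask, int oldIndex, int newIndex) {
        long moved = (1L << oldIndex) | (1L << newIndex);
        return (lineMask & moved) != 0L;
    }
}
```

For example, the white line [58, 49] from the data above becomes the mask (1L << 58) | (1L << 49), and touches(mask, 49, 41) is true because square 49 lies on that line.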

I had a lot of faith in this approach, because I expected it to be faster than recalculating everything. After running some tests, however, I found that it is exactly as fast as before: neither slower nor faster, which I don't really understand.

I also ran a profiler. These are the results: (image: Profiler results)

From what I can tell, the update function itself is causing the slowdown, not the functions it calls. Is my approach flawed, or am I making poor use of cache locality? The attackedTiles data structure might not be the most elegant, but is it also slow to iterate through? Or do I just need a more efficient way of checking whether a piece is seen by another piece, so that I can prune that list and avoid scanning all of it?
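One way to prune the list, sketched here as an assumption rather than a known fix, is a reverse index from each square to the lines of sight that pass through it. A move then only visits the lines registered on oldIndex and newIndex, instead of scanning every piece of both colors. The class, the integer line ids, and the replace method are hypothetical names for illustration:

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;

// Hypothetical reverse index: square (0..63) -> ids of the lines of sight passing through it.
final class LineOfSightIndex {
    private final List<List<Integer>> lines = new ArrayList<>();   // line id -> squares on that line
    private final List<Set<Integer>> watchers = new ArrayList<>(); // square -> ids of lines covering it

    LineOfSightIndex() {
        for (int square = 0; square < 64; square++) {
            watchers.add(new HashSet<>());
        }
    }

    // Registers a new line of sight and returns the id used to refer to it later.
    int add(List<Integer> squares) {
        int id = lines.size();
        lines.add(new ArrayList<>(squares));
        for (int square : squares) {
            watchers.get(square).add(id);
        }
        return id;
    }

    // Swaps in the recalculated squares of a line, keeping the reverse index consistent.
    void replace(int id, List<Integer> newSquares) {
        for (int square : lines.get(id)) {
            watchers.get(square).remove(id);
        }
        lines.set(id, new ArrayList<>(newSquares));
        for (int square : newSquares) {
            watchers.get(square).add(id);
        }
    }

    // Ids of the only lines that can be affected by a move from oldIndex to newIndex.
    Set<Integer> affectedBy(int oldIndex, int newIndex) {
        Set<Integer> result = new TreeSet<>(watchers.get(oldIndex));
        result.addAll(watchers.get(newIndex));
        return result;
    }

    // Current squares of a registered line.
    List<Integer> squaresOf(int id) {
        return lines.get(id);
    }
}
```

In this sketch the caller would still skip the moved piece's own lines and call replace after recomputing each affected line, so the index stays consistent with the board.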

Answer by jvatechs

I'm not at all sure this will bring a significant increase in efficiency, but I think you should try these changes and see what happens. I changed two parts of your code:

    //update every line of sight that could see the moved piece
    for (List<Integer> lineOfSight : attackedTilesColor.get(i)) {
        if (lineOfSight.contains(oldIndex) || lineOfSight.contains(newIndex)) {
            updateLineOfSight(lineOfSight);
        }
    }

    //check if any piece is pinning anything
    List<Integer> piecePositionsTurn = piecePositions.get(turn);
    for (int piecePosition : piecePositionsTurn) {
        if (piecePosition == -1) {
            continue;
        }

        List<Integer> pinLine = calculatePinLine(piecePosition);
        if (!pinLine.isEmpty()) {
            pins.get(turn).add(pinLine);
        }
    }

and rewrote them using streams:

//update every line of sight that could see the moved piece
attackedTilesColor.get(i).stream()
        .filter(lineOfSight -> lineOfSight.contains(oldIndex) || lineOfSight.contains(newIndex))
        .forEach(this::updateLineOfSight);

//check if any piece is pinning anything
piecePositions.get(turn).stream()
        .filter(piecePosition -> piecePosition != -1)
        .map(this::calculatePinLine)
        .filter(pinLine -> !pinLine.isEmpty())
        .forEach(pins.get(turn)::add);
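Both versions visit the same elements and call calculatePinLine the same number of times, so switching to streams changes the style rather than the amount of work; whether it is faster has to be measured. A self-contained sketch comparing the two shapes, where calculatePinLine is passed in as a hypothetical stand-in for the method in the question:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.IntFunction;
import java.util.stream.Collectors;

// Hypothetical harness: calculatePinLine is passed in so the two loop shapes can be compared.
final class PinScan {
    // Loop version, as written in the question.
    static List<List<Integer>> withLoop(List<Integer> piecePositions,
                                        IntFunction<List<Integer>> calculatePinLine) {
        List<List<Integer>> pins = new ArrayList<>();
        for (int piecePosition : piecePositions) {
            if (piecePosition == -1) {
                continue; // skip -1 slots, as in the question
            }
            List<Integer> pinLine = calculatePinLine.apply(piecePosition);
            if (!pinLine.isEmpty()) {
                pins.add(pinLine);
            }
        }
        return pins;
    }

    // Stream version, as suggested in the answer: same filters, same calls, same order.
    static List<List<Integer>> withStream(List<Integer> piecePositions,
                                          IntFunction<List<Integer>> calculatePinLine) {
        return piecePositions.stream()
                .filter(piecePosition -> piecePosition != -1)
                .map(piecePosition -> calculatePinLine.apply(piecePosition))
                .filter(pinLine -> !pinLine.isEmpty())
                .collect(Collectors.toList());
    }
}
```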

UPDATE:

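Let's also try replacing List<List<List<Integer>>> with List<Set<List<Integer>>> for attackedTiles, because contains on a HashSet takes O(1) time on average, compared with O(n) for a List. Note that in the code below the contains calls are still made on the inner List<Integer> lines of sight, so the outer Set by itself does not make those lookups faster; for O(1) lookups, each line of sight would itself need to be a set: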

private HashMap<Integer, List<Set<List<Integer>>>> attackedTiles;
private HashMap<Integer, List<List<Integer>>> pins;

private void updateAttackedTilesAndPins(int oldIndex, int newIndex) {
    int[] colors = new int[]{Piece.WHITE, Piece.BLACK};
    int pieceIndex = Piece.index(tile[newIndex]);

    for (int color : colors) {
        List<Set<List<Integer>>> attackedTilesColor = attackedTiles.get(color);

        for (int i = 0; i < attackedTilesColor.size(); i++) {
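            //skip only the moved piece's own lines; in the other color's list, index i is a different piece
            if (color == turn && i == pieceIndex) {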
                continue;
            }

            //update every line of sight that could see the moved piece
            Set<List<Integer>> linesOfSight = attackedTilesColor.get(i);
            for (List<Integer> lineOfSight : linesOfSight) {
                if (lineOfSight.contains(oldIndex) || lineOfSight.contains(newIndex)) {
                    updateLineOfSight(lineOfSight);
                }
            }
        }
    }

    //calculate the attacked tiles for the moved piece
    attackedTiles.get(turn).set(pieceIndex, new HashSet<>(calculateAttackedTiles(newIndex)));

    pins.get(turn).clear();

    List<Integer> piecePositionsTurn = piecePositions.get(turn);
    for (int piecePosition : piecePositionsTurn) {
        if (piecePosition == -1) {
            continue;
        }

        List<Integer> pinLine = calculatePinLine(piecePosition);
        if (!pinLine.isEmpty()) {
            pins.get(turn).add(pinLine);
        }
    }
}
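If constant-time lookups are the goal, the sets have to be at the innermost level. A minimal sketch, assuming each line of sight can be stored as a LinkedHashSet<Integer>, which gives average O(1) contains while keeping the squares in ray order (useful if updateLineOfSight walks the ray); the class and method names are hypothetical. Because a line of sight holds at most 8 squares (the origin plus up to 7 along the ray), the gain over a linear scan may be small in practice:

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

// Hypothetical helpers: each line of sight is a LinkedHashSet<Integer> instead of a List<Integer>.
final class OrderedLineOfSight {
    // Copies the squares in ray order; LinkedHashSet keeps that insertion order.
    static Set<Integer> of(List<Integer> squares) {
        return new LinkedHashSet<>(squares);
    }

    // Same check as the original loop body, but each contains() is an average O(1) hash lookup.
    static boolean seesMove(Set<Integer> lineOfSight, int oldIndex, int newIndex) {
        return lineOfSight.contains(oldIndex) || lineOfSight.contains(newIndex);
    }

    // Returns the lines of one piece that must be recalculated after the move.
    static List<Set<Integer>> affected(List<Set<Integer>> pieceLines, int oldIndex, int newIndex) {
        List<Set<Integer>> result = new ArrayList<>();
        for (Set<Integer> lineOfSight : pieceLines) {
            if (seesMove(lineOfSight, oldIndex, newIndex)) {
                result.add(lineOfSight);
            }
        }
        return result;
    }
}
```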