Efficiently find neighbors in Python


I have a three-dimensional grid. For each point on the grid, I want to find its nearest neighbours. Since my grid is uniformly sampled, I just want to gather the immediate neighbors.

Sample grid:

[Image: sample grid]

Required neighbors:

For a given point X, I need the following neighbors:

[Image: neighbors required for point X]

My Current Working Code:

import numpy as np
import cProfile

class Neighbours:
    # Get neighbors
    @classmethod
    def get_neighbour_indices(cls, row, col, frame, distance=1):                
        # Define the indices for the neighbor pixels    
        r = np.linspace(row - distance, row + distance, 2 * distance + 1)
        c = np.linspace(col - distance, col + distance, 2 * distance + 1)
        f = np.linspace(frame - distance, frame + distance, 2 * distance + 1)
        nc, nr, nf = np.meshgrid(c, r, f)
        neighbors = np.vstack((nr.flatten(), nc.flatten(), nf.flatten())).T
    
        # Filter out valid neighbor indices within the array bounds
        valid_indices = (neighbors[ :, 0] >= 0) & (neighbors[ :, 0] < nRows) & (neighbors[ :, 1] >= 0) & (neighbors[ :, 1] < nCols) & (neighbors[ :, 2] >= 0) & (neighbors[ :, 2] < nFrames) 

        # Return the valid neighbor indices
        valid_neighbors = neighbors[valid_indices]
        return valid_neighbors

    @classmethod
    def MapIndexVsNeighbours(cls):
        neighbours_info = np.empty((nRows * nCols * nFrames), dtype=object)
        for frame in range(nFrames):
            for row in range(nRows):
                for col in range(nCols):        
                    neighbour_indices = cls.get_neighbour_indices(row, col, frame, distance=1)        
                    flat_idx = frame * (nRows * nCols) + (row * nCols + col)
                    neighbours_info[flat_idx] = neighbour_indices                            

        return neighbours_info


########################------------------main()-------##################
####--run
if __name__ == "__main__":  
    nRows = 151
    nCols = 151
    nFrames = 24
    cProfile.run('Neighbours.MapIndexVsNeighbours()', sort='cumulative')  
    print() 

Problem: For larger grids (e.g. 201 x 201 x 24), the program takes a very long time. In the cProfile results, I can see that meshgrid() in get_neighbour_indices() takes quite a long time. All in all, this is not an efficient implementation. Furthermore, I tried to run MapIndexVsNeighbours() on a separate thread, but because of the GIL it does not really execute in parallel. So an implementation that can be executed in parallel would be desirable.

Answer by Andrej Kesely:

You can speed up the computation, for example with numba:

from timeit import timeit

import numpy as np
from numba import njit


@njit
def get_arr(x, y, f, w, h, distance, frames):
    _x_from = max(0, x - distance)
    _x_to = min(w - 1, x + distance)

    _y_from = max(0, y - distance)
    _y_to = min(h - 1, y + distance)

    _z_from = max(0, f - distance)
    _z_to = min(frames - 1, f + distance)

    shape = ((_x_to + 1) - _x_from) * ((_y_to + 1) - _y_from) * ((_z_to + 1) - _z_from)
    out = np.empty(shape=(shape, 3), dtype=np.int32)

    i = 0
    for _x in range(_x_from, _x_to + 1):
        for _y in range(_y_from, _y_to + 1):
            for _z in range(_z_from, _z_to + 1):
                out[i, 0] = _x
                out[i, 1] = _y
                out[i, 2] = _z
                i += 1
    return out


def MapIndexVsNeighbours_numba(nRows, nCols, nFrames):
    neighbours_info = np.empty((nRows * nCols * nFrames), dtype=object)
    for frame in range(nFrames):
        for row in range(nRows):
            for col in range(nCols):
                neighbour_indices = get_arr(row, col, frame, nRows, nCols, 1, nFrames)
                flat_idx = frame * (nRows * nCols) + (row * nCols + col)
                neighbours_info[flat_idx] = neighbour_indices
    return neighbours_info


nRows = 151
nCols = 151
nFrames = 24

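# Check that both versions return the same neighbors for every grid point.
# Neighbours is the class from the question; it reads nRows, nCols and nFrames as globals.
v1 = MapIndexVsNeighbours_numba(nRows, nCols, nFrames)
v2 = Neighbours.MapIndexVsNeighbours()
assert all(np.allclose(a, b) for a, b in zip(v1, v2))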

t_numba = timeit(
    "MapIndexVsNeighbours_numba(nRows, nCols, nFrames)", number=1, globals=globals()
)
t_original = timeit("Neighbours.MapIndexVsNeighbours()", number=1, globals=globals())

print(f"{t_numba=}")
print(f"{t_original=}")

Prints on my machine (AMD 5700X), times in seconds:

t_numba=0.46780765103176236
t_original=31.86005672905594

  • With 201 x 201 x 24 the numba function took 1.1439994741231203 s
  • With 1024 x 1024 x 24 the numba function took 22.353679422987625 s
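
Since the grid is uniform, another option is to skip the per-point Python loop entirely and build the whole table with NumPy broadcasting. The following is only a minimal sketch under some assumptions, not a drop-in replacement: the function name neighbour_table and its return layout (a dense index array plus a boolean validity mask instead of an object array of ragged arrays) are hypothetical choices made for illustration. It has not been benchmarked against the numba version, and the dense array needs N x 27 x 3 int32 values for distance=1, which gets large for big grids.

```python
import numpy as np


def neighbour_table(n_rows, n_cols, n_frames, distance=1):
    """Hypothetical helper: neighbours of every grid point in one shot.

    Returns
    -------
    neighbours : int32 array, shape (N, K, 3), columns are (row, col, frame)
    valid      : bool array, shape (N, K), True where the neighbour is in bounds
    with N = n_rows * n_cols * n_frames and K = (2 * distance + 1) ** 3.
    """
    # Offsets of the surrounding cube, computed once (row offset varies slowest).
    d = np.arange(-distance, distance + 1, dtype=np.int32)
    offsets = np.stack(np.meshgrid(d, d, d, indexing="ij"), axis=-1).reshape(-1, 3)

    # All grid points, ordered so that position i matches
    # flat_idx = frame * (n_rows * n_cols) + row * n_cols + col.
    f, r, c = np.meshgrid(
        np.arange(n_frames, dtype=np.int32),
        np.arange(n_rows, dtype=np.int32),
        np.arange(n_cols, dtype=np.int32),
        indexing="ij",
    )
    points = np.stack((r.ravel(), c.ravel(), f.ravel()), axis=1)

    # Broadcast: (N, 1, 3) + (1, K, 3) -> (N, K, 3)
    neighbours = points[:, None, :] + offsets[None, :, :]

    # A neighbour is valid if all three coordinates are inside the grid.
    limits = np.array([n_rows, n_cols, n_frames], dtype=np.int32)
    valid = ((neighbours >= 0) & (neighbours < limits)).all(axis=-1)
    return neighbours, valid


# Usage on a small grid: neighbours of the point frame=1, row=0, col=2.
nb, ok = neighbour_table(4, 5, 3)
i = np.ravel_multi_index((1, 0, 2), (3, 4, 5))  # (frame, row, col) -> flat index
print(nb[i][ok[i]])
```

Here nb[i][ok[i]] plays the role of neighbours_info[flat_idx] in the code above, and, like the original, it includes the point itself. If memory is a concern, the offsets array alone can be kept and added to a single point on demand.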