How to recreate axis = 2 in np.repeat() for Numba (repeat array in the last dimension)


I'm trying to convert my code to a Numba-friendly implementation, but I keep running into errors with the axis argument of np.repeat(), since Numba does not support it. Specifically, I need np.repeat() with axis=2 or, more generally, a way to repeat the array along the last dimension.

In numpy my code is:

import numpy as np

original = np.random.rand(1000,1)
no_repeats = 10
big_original = np.repeat(np.expand_dims(5*original, axis=2), no_repeats, axis=2)
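For reference, this is what the NumPy expression produces, i.e. the array any Numba version has to reproduce. It is a small check added for illustration, not part of the original question: `expand_dims` appends a length-1 axis and `repeat` then copies it `no_repeats` times.

```python
import numpy as np

original = np.random.rand(1000, 1)
no_repeats = 10

# expand_dims appends a length-1 axis: (1000, 1) -> (1000, 1, 1)
expanded = np.expand_dims(5 * original, axis=2)
# repeat along that new axis copies it no_repeats times: -> (1000, 1, 10)
big_original = np.repeat(expanded, no_repeats, axis=2)

print(expanded.shape, big_original.shape)  # (1000, 1, 1) (1000, 1, 10)

# Every slice along the last axis is the same scaled copy of the input
for k in range(no_repeats):
    assert np.array_equal(big_original[..., k], 5 * original)
```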

How could I rewrite this in a Numba-friendly way?

I have tried to use np.dstack:

expanded_original = np.expand_dims((5)*original, axis=2)
big_original = np.dstack([expanded_original]*no_repeats)

But of course Numba doesn't accept a list here. What is the most efficient way to go about this?



Answer by Salvatore Daniele Bianco

I don't know exactly what you are trying to do, but I guess you want to reproduce the big_original array inside a @njit-compiled Numba function. Right?

If so:

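import numpy as np
from numba import njit

@njit
def repeat_original(original, no_repeats):
    # Allocate the output with an extra trailing axis of length no_repeats
    big_original = np.zeros(original.shape + (no_repeats,))
    scaled = 5 * original
    for i in range(big_original.shape[-1]):
        # Copy the scaled input into each slice along the new last axis
        big_original[..., i] = scaled
    return big_original

big_original = repeat_original(original, no_repeats)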

If this is not the answer you were expecting, please try to specify your problem better (e.g. what expandedGradientMatrix is) and give your expected output.
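A variant of the loop above, sketched under the assumption that the input is a floating-point array of any number of dimensions and that the new axis always goes last. `repeat_last_axis` and its `factor` argument are hypothetical names for illustration; the `try`/`except` only lets the snippet run as plain Python where Numba isn't installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback so the sketch still runs without Numba
    def njit(func):
        return func


@njit
def repeat_last_axis(original, no_repeats, factor):
    # Assumes float input: np.empty defaults to float64 here
    flat = original.ravel()
    out = np.empty((flat.size, no_repeats))
    for i in range(flat.size):
        value = factor * flat[i]  # scale once per input element
        for k in range(no_repeats):
            out[i, k] = value
    # Restore the input shape and append the repeat axis at the end
    return out.reshape(original.shape + (no_repeats,))


original = np.random.rand(1000, 1)
big = repeat_last_axis(original, 10, 5.0)
```

Looping over the flattened input makes the function independent of `original.ndim`, and scaling each element once avoids allocating a temporary scaled array.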

Answer by Felix Zimmermann

If the input is always 2D, or by axis=2 you mean "after the last dimension", this will work:

import numpy as np
import numba


original = np.random.rand(1000,1)
no_repeats = 10
big_original = np.repeat(np.expand_dims(original, axis=2), no_repeats, axis=2)


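@numba.njit
def repeatnumba(original, no_repeats):
    # repeat() without axis flattens and repeats each element consecutively;
    # reshaping then puts the copies on a new trailing axis
    repeat = original.repeat(no_repeats).reshape(original.shape + (no_repeats,))
    return repeat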

big_numba = repeatnumba(original, no_repeats)

print(np.allclose(big_original, big_numba))

It reproduces the same result as your NumPy code. Note that this relies on the repeated axis being the last one, which is axis 2 only when the input is 2D. If your true number of dimensions is different, you might want to use np.transpose with an axes tuple, for example:

repeat = np.transpose(repeat, (0, 1, -1, 2))

if your input is 3D and you still want to repeat along axis 2.
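To illustrate the transpose step, here is a pure NumPy check with a made-up 3D input of shape (4, 3, 5) and 6 repeats. The repeat-then-reshape line is the same computation `repeatnumba` performs; it is simply run outside Numba here.

```python
import numpy as np

x = np.random.rand(4, 3, 5)  # example 3D input
no_repeats = 6

# Reference: insert a new axis at position 2 and repeat along it -> (4, 3, 6, 5)
expected = np.repeat(np.expand_dims(x, axis=2), no_repeats, axis=2)

# repeat() without axis flattens x and repeats each element consecutively,
# so reshaping puts the copies on a new last axis -> (4, 3, 5, 6)
repeat = x.repeat(no_repeats).reshape(x.shape + (no_repeats,))

# Swap the last two axes to move the copies to axis 2 -> (4, 3, 6, 5)
repeat = np.transpose(repeat, (0, 1, -1, 2))

print(repeat.shape)  # (4, 3, 6, 5)
```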

If you meant "along the last dimension", please consider updating your title and problem description such that the question might be of more general use.