Implementing if-then-elif-then-else in JAX


I'm just starting to use JAX, and I wonder what the right way is to implement if-then-elif-then-else in JAX/Python. For example, given the input arrays n = [5, 4, 3, 2] and k = [3, 3, 3, 3], I need to implement the following pseudo-code:

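import jax.numpy as jnp

def n_choose_k_safe(n, k):
    # Plain Python loop: runs eagerly, but not under jax.jit, because the
    # Python `if` needs concrete values. func_nchoosek is the asker's own function.
    r = jnp.zeros(4)
    for i in range(4):
        if n[i] < k[i]:
            r = r.at[i].set(0)  # JAX arrays are immutable, so r[i] = 0 fails
        elif n[i] == k[i]:
            r = r.at[i].set(1)
        else:
            r = r.at[i].set(func_nchoosek(n[i], k[i]))
    return r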

There are so many choices, like jax.vmap, lax.select, jnp.where, lax.cond, lax.fori_loop, etc., that it is hard to decide which combination of these utilities to use. By the way, k can be a scalar (if that makes it simpler).
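For comparison, the most literal translation of the if / elif / else is a nested jax.lax.cond applied per element with jax.vmap. This is only a sketch: the log-gamma func_nchoosek below is a hypothetical stand-in for the asker's function, assumed to take and return scalars. Note that when the predicate is batched, vmap converts lax.cond into a select, so every branch is still computed for every element; that is one reason the jnp.where / jnp.select approaches are usually preferred.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln


def func_nchoosek(n, k):
    # Hypothetical stand-in for the asker's function: C(n, k) via log-gamma.
    return jnp.exp(gammaln(n + 1.0) - gammaln(k + 1.0) - gammaln(n - k + 1.0))


def n_choose_k_scalar(n, k):
    # Literal if / elif / else for a single (n, k) pair, using nested lax.cond.
    return jax.lax.cond(
        n < k,
        lambda n, k: jnp.zeros((), n.dtype),        # if n < k: 0
        lambda n, k: jax.lax.cond(
            n == k,
            lambda n, k: jnp.ones((), n.dtype),     # elif n == k: 1
            func_nchoosek,                          # else: C(n, k)
            n, k),
        n, k)


# Map over axis 0 of n; in_axes=None keeps the scalar k fixed for every element.
n_choose_k_safe = jax.jit(jax.vmap(n_choose_k_scalar, in_axes=(0, None)))

n = jnp.array([5.0, 4.0, 3.0, 2.0])
result = n_choose_k_safe(n, 3.0)
print(result)
```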

Answer by Valentin Goldité

First, you can vectorize func_nchoosek so that it accepts n and k as flat vectors. This assumes func_nchoosek works on scalar inputs (shape ()), which is what jax.vmap passes it when mapping over axis 0 of a 1-D array; if it does not, you would need to adapt it first. Then:

func_nchoosek_vect = jax.vmap(func_nchoosek, (0, 0), 0)

Now func_nchoosek_vect([n1, n2, ...], [k1, k2, ...]) = [func_nchoosek(n1, k1), func_nchoosek(n2, k2), ...]: the operations are done element-wise (similar to zip).
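To check the zip-like behaviour, here is a small sketch that compares the vmapped function against an explicit Python loop over pairs. The log-gamma func_nchoosek is a hypothetical scalar-in, scalar-out stand-in, not the asker's implementation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln


def func_nchoosek(n, k):
    # Hypothetical stand-in: C(n, k) via log-gamma, scalar in, scalar out.
    return jnp.exp(gammaln(n + 1.0) - gammaln(k + 1.0) - gammaln(n - k + 1.0))


# in_axes=(0, 0): walk axis 0 of n and k in lockstep, like zip(n, k).
func_nchoosek_vect = jax.vmap(func_nchoosek, (0, 0), 0)

n = jnp.array([5.0, 4.0, 6.0])
k = jnp.array([3.0, 3.0, 2.0])
out = func_nchoosek_vect(n, k)

# The same values computed one (n_i, k_i) pair at a time.
manual = jnp.stack([func_nchoosek(a, b) for a, b in zip(n, k)])
print(out)
```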

If k is a single scalar, you can use this instead:

func_nchoosek_vect = jax.vmap(func_nchoosek, (0, None), 0)
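With in_axes=(0, None), only n is mapped and the same k is passed to every call. A minimal sketch, again with a hypothetical log-gamma func_nchoosek:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln


def func_nchoosek(n, k):
    # Hypothetical stand-in: C(n, k) via log-gamma, scalar in, scalar out.
    return jnp.exp(gammaln(n + 1.0) - gammaln(k + 1.0) - gammaln(n - k + 1.0))


# None means "do not map this argument": k is broadcast to every element of n.
func_nchoosek_vect = jax.vmap(func_nchoosek, (0, None), 0)

n = jnp.array([5.0, 4.0, 6.0])
out = func_nchoosek_vect(n, 3.0)
print(out)
```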

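Now you can use jnp.where to select the data you want. It is like lax.select but more flexible: it broadcasts its arguments (so scalars such as 1 or 0 are fine) and promotes dtypes, whereas lax.select requires both branches to have the same shape and dtype. jnp.where is compatible with jit compilation as long as you call it with 3 arguments; the 1-argument form jnp.where(cond) returns indices whose shape depends on the data. It also preserves gradients, with one caveat: both branches are always evaluated, so a NaN or infinite derivative in the unselected branch (for example, func_nchoosek evaluated where n < k) can still turn the gradient into NaN.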
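The sketch below illustrates both points with small illustrative functions (naive and safe are hypothetical names, not part of the answer): jnp.where broadcasts a scalar where lax.select needs a full array, and a NaN gradient can leak from the branch that was not selected. The "double where" in safe, which feeds the unused branch a harmless input, is a commonly used workaround.

```python
import jax
import jax.numpy as jnp

x = jnp.array([-1.0, 0.0, 4.0])

# jnp.where broadcasts the scalar 0.0 against x ...
a = jnp.where(x > 0, x, 0.0)
# ... whereas lax.select needs operands with identical shape and dtype.
b = jax.lax.select(x > 0, x, jnp.zeros_like(x))


def naive(x):
    # log is evaluated even where x <= 0; its derivative 1/x is inf at 0,
    # and the zero cotangent from where gives 0 * inf = nan.
    return jnp.where(x > 0, jnp.log(x), 0.0)


def safe(x):
    # "Double where": replace bad inputs before they reach the unused branch.
    x_safe = jnp.where(x > 0, x, 1.0)
    return jnp.where(x > 0, jnp.log(x_safe), 0.0)


g_naive = jax.grad(naive)(0.0)
g_safe = jax.grad(safe)(0.0)
print(g_naive, g_safe)
```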

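import jax
import jax.numpy as jnp

def n_choose_k_safe(n: jax.Array, k: jax.Array) -> jax.Array:
  """Choose k among n with safety."""
  # -1 is only a placeholder: every entry with n <= k is overwritten below.
  r = jnp.where(n > k, func_nchoosek_vect(n, k), -1)
  r = jnp.where(n == k, 1, r)  # n == k  ->  1
  r = jnp.where(n < k, 0, r)   # n < k   ->  0
  return r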
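Putting the pieces together, a runnable sketch on the question's inputs might look like this. The log-gamma func_nchoosek is again a hypothetical stand-in for the asker's function, and jax.jit is applied only to show that the where-chain traces without trouble.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln


def func_nchoosek(n, k):
    # Hypothetical stand-in: C(n, k) via log-gamma, scalar in, scalar out.
    return jnp.exp(gammaln(n + 1.0) - gammaln(k + 1.0) - gammaln(n - k + 1.0))


func_nchoosek_vect = jax.vmap(func_nchoosek, (0, 0), 0)


@jax.jit
def n_choose_k_safe(n, k):
    r = jnp.where(n > k, func_nchoosek_vect(n, k), -1.0)  # placeholder -1
    r = jnp.where(n == k, 1.0, r)
    r = jnp.where(n < k, 0.0, r)
    return r


n = jnp.array([5.0, 4.0, 3.0, 2.0])
k = jnp.array([3.0, 3.0, 3.0, 3.0])
result = n_choose_k_safe(n, k)
print(result)
```

func_nchoosek_vect is still evaluated for the entries where n < k; its output there is simply discarded by the later jnp.where calls.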
Answer by jakevdp

There's a slightly more compact way to express the solution in Valentin's answer, using jax.numpy.select:

import jax.numpy as jnp

def n_choose_k_safe(n, k):
  return jnp.select(condlist=[n > k, n == k],
                    choicelist=[jnp.vectorize(func_nchoosek)(n, k), 1],
                    default=0)

For input arrays of length 4, this should return the same result as your original code, assuming func_nchoosek is compatible with jax.vmap. Using jnp.vectorize here in place of vmap also makes the function work with a scalar k: jnp.vectorize broadcasts its inputs, so there is no need to set the in_axes argument manually.
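A self-contained version of this answer, with a hypothetical log-gamma func_nchoosek standing in for the real one, suggests that the same jitted function accepts both an array k and a scalar k:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln


def func_nchoosek(n, k):
    # Hypothetical stand-in: C(n, k) via log-gamma, scalar in, scalar out.
    return jnp.exp(gammaln(n + 1.0) - gammaln(k + 1.0) - gammaln(n - k + 1.0))


@jax.jit
def n_choose_k_safe(n, k):
    # First true condition wins: n > k -> C(n, k), n == k -> 1, otherwise 0.
    # jnp.vectorize broadcasts n against k, so k may be an array or a scalar.
    return jnp.select(condlist=[n > k, n == k],
                      choicelist=[jnp.vectorize(func_nchoosek)(n, k), 1.0],
                      default=0.0)


n = jnp.array([5.0, 4.0, 3.0, 2.0])
r_array_k = n_choose_k_safe(n, jnp.array([3.0, 3.0, 3.0, 3.0]))
r_scalar_k = n_choose_k_safe(n, 3.0)
print(r_array_k, r_scalar_k)
```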