I'm just starting to use JAX, and I wonder—what would be the right way to implement if-then-elif-then-else in JAX/Python? For example, given input arrays: n = [5, 4, 3, 2] and k = [3, 3, 3, 3], I need to implement the following pseudo-code:
def n_choose_k_safe(n, k):
    r = jnp.empty(4)
    for i in range(4):
        if n[i] < k[i]:
            r[i] = 0
        elif n[i] == k[i]:
            r[i] = 1
        else:
            r[i] = func_nchoosek(n[i], k[i])
    return r
There are so many choices, like vmap, lax.select, lax.where, jax.cond, lax.fori_loop, etc., that it is hard to decide which combination of these utilities to use. By the way, k can be a scalar (if that makes it simpler).
First, you can vectorize the function func_nchoosek (for example with jax.vmap) so that it accepts n and k as flat vectors. This assumes that func_nchoosek works on a single pair of scalar inputs; if it does not, that needs to be arranged first.

Calling the vectorized version func_nchoosek_vect, you get func_nchoosek_vect([n1, n2, ...], [k1, k2, ...]) = [func_nchoosek(n1, k1), func_nchoosek(n2, k2), ...]: the operations are done element-wise (similar to zip).

If k is a single scalar, you can instead map over n only and pass k through unchanged (for example with in_axes=(0, None) in jax.vmap).

Now you can use the function jnp.where to select the data you want. It is like lax.select but more flexible. The function is compatible with jit compilation and preserves gradients (under some further assumptions) as long as you use it with 3 arguments (to have deterministic shapes).
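Putting the pieces together, a minimal sketch could look like the following. The body of func_nchoosek here is only a hypothetical stand-in (a binomial coefficient computed through the log-gamma function), since the question does not show the real one; the conversion to float32 is likewise an assumption made for the example. The vectorized function is evaluated for every element, including those where n < k, and the nested jnp.where then picks the right value per element.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln


def func_nchoosek(n, k):
    # Hypothetical stand-in for the asker's function: binomial coefficient
    # computed as exp(log n! - log k! - log (n - k)!) on scalar inputs.
    return jnp.exp(gammaln(n + 1.0) - gammaln(k + 1.0) - gammaln(n - k + 1.0))


@jax.jit
def n_choose_k_safe(n, k):
    # Work in floating point (an assumption of this sketch).
    n = jnp.asarray(n, dtype=jnp.float32)
    k = jnp.asarray(k, dtype=jnp.float32)

    if k.ndim == 0:
        # k is a scalar: map over n only and pass k through unchanged.
        full = jax.vmap(func_nchoosek, in_axes=(0, None))(n, k)
    else:
        # n and k are flat vectors of equal length: map over both.
        full = jax.vmap(func_nchoosek)(n, k)

    # if n < k -> 0, elif n == k -> 1, else -> func_nchoosek(n, k).
    # All three candidates are computed; jnp.where only selects among them.
    return jnp.where(n < k, 0.0, jnp.where(n == k, 1.0, full))


n = jnp.array([5, 4, 3, 2])
k = jnp.array([3, 3, 3, 3])
print(n_choose_k_safe(n, k))  # vector k
print(n_choose_k_safe(n, 3))  # scalar k
```

For the inputs from the question, both calls should give roughly [10, 4, 1, 0], up to floating-point rounding. The Python-level check on k.ndim is resolved at trace time, so it works under jit. Because the unselected branch is still evaluated, gradients can become NaN if func_nchoosek produces non-finite values for the masked-out inputs; this is one of the "further assumptions" mentioned above.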