Currently jax.lax.cond works for one boolean condition. Is there a way to extend it to multiple boolean conditions?
As an example, below is an untraceable function:
def func(x):
    if x < 0:
        return x
    elif (x >= 0) & (x < 1):
        return 2*x
    else:
        return 3*x
How can I write this function in JAX in a traceable way?
One compact way to write something like this is using jnp.select:
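A minimal sketch of that approach, applied to the `func` from the question. The `jax.jit` decorator is only there to show that the function traces; it is an addition for illustration, not part of the original question.

```python
import jax
import jax.numpy as jnp


@jax.jit  # added for illustration: confirms the function is traceable
def func(x):
    # Conditions are tested in order; the first one that is True selects
    # the matching entry of choicelist. `default` covers the `else` branch.
    return jnp.select(
        condlist=[x < 0, (x >= 0) & (x < 1)],
        choicelist=[x, 2 * x],
        default=3 * x,
    )


# Works elementwise on arrays as well as on scalars.
out = func(jnp.array([-2.0, 0.5, 2.0]))  # values: -2.0, 1.0, 6.0
```

Note that, unlike `jax.lax.cond`, this computes every branch and then picks the result elementwise, so it fits best when the branches are cheap and free of side effects.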