I would like to find the permutation parity (sign) for a given batch of vectors (in Python/JAX).
n = jnp.array([[[0., 0., 1., 1.],
                [0., 0., 1., 1.],
                [1., 1., 0., 0.],
                [1., 1., 0., 0.]],
               [[1., 0., 1., 0.],
                [0., 1., 0., 1.],
                [1., 0., 1., 0.],
                [0., 1., 0., 1.]],
               [[0., 1., 1., 0.],
                [1., 0., 0., 1.],
                [1., 0., 0., 1.],
                [0., 1., 1., 0.]]])
sorted_index = jax.vmap(sorted_idx)(n)
sorted_perms = jax.vmap(jax.vmap(sorted_perm, in_axes=(0, 0)), in_axes=(0,0))(n, sorted_index)
parities = jax.vmap(parities)(sorted_index)
I expect the following solution:
sorted_elements = [[[0., 0., 1., 1.],
                    [0., 0., 1., 1.],
                    [0., 0., 1., 1.],
                    [0., 0., 1., 1.]],
                   [[0., 0., 1., 1.],
                    [0., 0., 1., 1.],
                    [0., 0., 1., 1.],
                    [0., 0., 1., 1.]],
                   [[0., 0., 1., 1.],
                    [0., 0., 1., 1.],
                    [0., 0., 1., 1.],
                    [0., 0., 1., 1.]]]
parities = [[1, 1, 1, 1],
            [-1, -1, -1, -1],
            [1, 1, 1, 1]]
I tried the following:
# sort the array and return the arg_sort indices
def sorted_idx(permutations):
    sort_idx = jnp.argsort(permutations)
    return sort_idx

# sort the permutations (vectors) given the sorted_indices
def sorted_perm(permutations, sort_idx):
    perm = permutations[sort_idx]
    return perm

# Calculate the permutation cycle, from which we compute the permutation parity
@jax.vmap
def parities(sort_idx):
    length = len(sort_idx)
    elements_seen = jnp.zeros(length)
    cycles = 0
    for index in range(length):
        if elements_seen[index] == True:
            continue
        cycles += 1
        current = index
        if elements_seen[current] == False:
            elements_seen.at[current].set(True)
            current = sort_idx[current]
    is_even = (length - cycles) % 2 == 0
    return +1 if is_even else -1
But I get the following: parities = [[1 1 1 1], [1 1 1 1], [1 1 1 1]]
Every permutation vector gets a parity factor of 1, which is wrong.
The reason your routine doesn't work is that you're attempting to vmap over Python control flow, which must be done very carefully (see JAX Sharp Bits: Control Flow). I suspect it would be a bit complicated to express your iterative parity approach in terms of jax.lax control flow operators, but there might be another way.

The parity of a permutation is the determinant of its permutation matrix, and the Jacobian of sort happens to be exactly that matrix, so you could (ab)use JAX's automatic differentiation of the sort operator to compute the parities very concisely: take the determinant of jax.jacobian(jnp.sort) evaluated at each vector.

This does end up being O(N^3), where N is the length of the permutations, but due to the nature of XLA computations, particularly on accelerators like GPU, the vectorized approach will likely be more efficient than an iterative approach for reasonably sized N.

Also note that there's no reason to compute the sorted_index with this implementation; you could call compute_parity directly on your array n instead.
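
A minimal sketch of the Jacobian-of-sort idea, applied directly to the array n from the question. The body of compute_parity shown here is an assumption based on the description above (determinant of the Jacobian of jnp.sort); the rounding and the integer cast are extra guards against floating-point error, and the double vmap is just one way to map over the two leading batch axes.

```python
import jax
import jax.numpy as jnp


def compute_parity(p):
    # The Jacobian of sort at p is the permutation matrix of the sorting
    # permutation; its determinant is +1 (even) or -1 (odd).
    jac = jax.jacobian(jnp.sort)(p.astype(jnp.float32))
    # Round before casting so tiny floating-point error cannot flip the value.
    return jnp.round(jnp.linalg.det(jac)).astype(jnp.int32)


n = jnp.array([[[0., 0., 1., 1.],
                [0., 0., 1., 1.],
                [1., 1., 0., 0.],
                [1., 1., 0., 0.]],
               [[1., 0., 1., 0.],
                [0., 1., 0., 1.],
                [1., 0., 1., 0.],
                [0., 1., 0., 1.]],
               [[0., 1., 1., 0.],
                [1., 0., 0., 1.],
                [1., 0., 0., 1.],
                [0., 1., 1., 0.]]])

# Map over both leading axes so compute_parity sees one length-4 vector at a time.
parities = jax.vmap(jax.vmap(compute_parity))(n)
print(parities)
```

Because the vectors contain repeated values, the parity reported is that of the permutation the sort actually applies (the one argsort would return), which for this input should match the parities expected in the question.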