How can I keep track of the parity of a permutation during sort?


std::sort (and related algorithms) is one of the crown jewels of the STL. However, when I try to use it in numerical linear algebra, I run into a glaring problem that prevents me from using it.

The issue is that in mathematics, the parity (even/odd) of a permutation usually matters, for example when computing determinants.
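For distinct values, the parity of the permutation that sorts a sequence equals the parity of its inversion count (pairs i < j with a[i] > a[j]): swapping two adjacent distinct values changes that count by exactly one, and a sorted sequence has none. A brute-force O(n²) sketch of that definition, useful as a cross-check for whatever tracking method is used; the function name `inversion_parity_is_even` is hypothetical and the code assumes distinct values:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical reference check: returns true if the permutation that sorts `a`
// into ascending order is even. Counts inversions (pairs i < j with a[i] > a[j])
// by brute force in O(n^2); only the parity of the count is kept.
// Assumes distinct values: with ties the sorting permutation is not unique.
bool inversion_parity_is_even(const std::vector<int>& a)
{
    bool even = true;
    for (std::size_t i = 0; i < a.size(); ++i)
        for (std::size_t j = i + 1; j < a.size(); ++j)
        {
            assert(a[i] != a[j] && "values must be distinct");
            if (a[i] > a[j])
                even = !even;   // each inversion flips the parity
        }
    return even;
}
```

With the data from the answer below, { 17, 0, 19, 13, 4, -5, -6, 27, -31 }, there are 25 inversions, so this reports odd, in agreement with the cycle-based result.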

I can think of 3 ways to keep track of the parity:

  1. Zip the range to be sorted with a trivial 0, 1, ..., n-1 index sequence and read off the permutation after sorting. (This is a sure way, but it does an unreasonable amount of extra work.)
  2. Hack the comparison operation to flip a sign each time a given pair is out of order. Of course, this is bound to fail in general, as I don't know whether the number of swaps equals the number of "failed" comparisons.
  3. Wrap the sequence in a type with a custom swap that has the side effect of keeping a count (or the sign) of the swaps performed so far.

Of course, these are all horrible solutions, and I am looking for something better if it exists. Is there an algorithm close to std::sort that would let me keep track of the parity without reimplementing sort from scratch?



Answer by lastchance:

As suggested, you can just sort indices, using a comparator that compares the original array's values at those indices. The parity can then be computed by an O(n) function that follows the cycles of the permuted indices.
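The O(n) cycle walk relies on a standard fact: a cycle of length ℓ can be written as a product of ℓ − 1 transpositions. Writing c(σ) for the number of cycles of σ (fixed points included) and ℓ_k for their lengths, the sign follows directly:

```latex
\operatorname{sgn}(\sigma) \;=\; \prod_{k=1}^{c(\sigma)} (-1)^{\ell_k - 1} \;=\; (-1)^{\,n - c(\sigma)},
\qquad \text{since } \sum_{k=1}^{c(\sigma)} \ell_k = n .
```

In the example below, the sorted index vector is 8 6 5 1 4 3 0 2 7, which consists of one 8-cycle (0 → 8 → 7 → 2 → 5 → 3 → 1 → 6 → 0) and the fixed point 4, so n − c = 9 − 2 = 7 and the permutation is odd.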

#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>
using namespace std;

bool parity( const vector<int> &V )    // returns true for even, false for odd
{                                      // V is a permutation of 0, 1, ..., n-1 (NOTE: zero-based!)
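   vector<bool> visited( V.size(), false );
   const int n = static_cast<int>( V.size() );   // avoid signed/unsigned comparison
   bool result = true;
   for ( int start = 0; start < n; start++ )
   {
      if ( visited[start] ) continue;
      visited[start] = true;
      // walk the cycle containing start; a cycle of length L is L-1 transpositions,
      // so flip the parity once for every element after the first
      for ( int j = V[start]; j != start; j = V[j] )
      {
         result = !result;
         visited[j] = true;
      }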
   }
   return result;
}


int main()
{
   vector<int> test = { 17, 0, 19, 13, 4, -5, -6, 27, -31 };

   vector<int> indices( test.size() );
   iota( indices.begin(), indices.end(), 0 );
   sort( indices.begin(), indices.end(), [&test]( int i, int j ){ return test[i] < test[j]; } );

   cout << "Sorted array: ";
   for ( int i : indices ) cout << test[i] << " ";
   cout << "\nParity: " << ( parity( indices ) ? "even" : "odd" ) << '\n';
}
Output:

Sorted array: -31 -6 -5 0 4 13 17 19 27
Parity: odd
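The same idea can be wrapped so that it is called like std::sort: sort an index vector, compute the parity of that index permutation, then move the elements into sorted order. This is a sketch assuming C++14 or later; `sort_with_parity` and `permutation_is_even` are hypothetical names, not standard library functions:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <iterator>
#include <numeric>
#include <utility>
#include <vector>

// Hypothetical helper: returns true if p (a permutation of 0, 1, ..., n-1) is even.
// Each cycle of length L contributes L-1 transpositions, so the parity flips
// once per element after the first in every cycle. O(n) time.
inline bool permutation_is_even(const std::vector<std::size_t>& p)
{
    const std::size_t n = p.size();
    std::vector<bool> visited(n, false);
    bool even = true;
    for (std::size_t start = 0; start < n; ++start)
    {
        if (visited[start]) continue;
        visited[start] = true;
        for (std::size_t j = p[start]; j != start; j = p[j])
        {
            // catches out-of-range or repeated entries (which would otherwise loop forever)
            assert(j < n && !visited[j] && "p must be a permutation of 0..n-1");
            visited[j] = true;
            even = !even;
        }
    }
    return even;
}

// Hypothetical helper: sorts [first, last) by comp and returns true if the
// permutation applied to the range was even. Uses std::sort on indices, so the
// extra cost is an index vector plus a temporary buffer of n elements.
template <class RandomIt, class Compare = std::less<>>
bool sort_with_parity(RandomIt first, RandomIt last, Compare comp = Compare{})
{
    using Value = typename std::iterator_traits<RandomIt>::value_type;
    const std::size_t n = static_cast<std::size_t>(std::distance(first, last));

    std::vector<std::size_t> idx(n);
    std::iota(idx.begin(), idx.end(), std::size_t{0});
    std::sort(idx.begin(), idx.end(), [&](std::size_t i, std::size_t j) {
        return comp(first[i], first[j]);
    });

    const bool even = permutation_is_even(idx);

    // Gather the elements in sorted order, then move them back into the range.
    std::vector<Value> sorted;
    sorted.reserve(n);
    for (std::size_t i : idx) sorted.push_back(std::move(first[i]));
    std::move(sorted.begin(), sorted.end(), first);
    return even;
}
```

Called on the `test` vector from the answer, it leaves the vector sorted and returns false (odd). If the range contains equivalent elements, std::sort may order them arbitrarily, so the reported parity can vary between implementations; using std::stable_sort on the indices instead would make it reproducible.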