How to replace nested IF/ELSE branches with SIMD (SSE or AVX)?


EDIT x 2

  • Added a more comprehensive function that returns an abstract register class: the function outputs a register full of floats. I don't care about the actual length (SSE, AVX...) because Google Highway will figure that out for me.
  • Working code available on Godbolt
  • test1, test2, test3 could also be masks

EDIT

  • Returned values are not known at compile time (these trivial values are there just to make the branching immediate to understand)
  • I fixed the branching that was so immediate to understand that I got it wrong (:

I'm new to SIMD and I'm using Google Highway to achieve a portable (x86 and ARM) solution, so I'm writing this question in general terms.

I'm trying to speed up this C/C++ code with SIMD instructions:

  const bool test1 = foo(input1) > 0; // unpredictable 
  const bool test2 = foo(input2) > 0; // unpredictable 
  const bool test3 = foo(input3) > 0; // unpredictable 

  const RegisterWithFourFloats out0;  // not known at compile time
  const RegisterWithFourFloats out1;  // not known at compile time
  const RegisterWithFourFloats out12; // not known at compile time
  const RegisterWithFourFloats out13; // not known at compile time
  const RegisterWithFourFloats out123;// not known at compile time

if (test1)
  if(test2) 
    if (test3)
       return out123; 
    else
       return out12; 
  else
    if (test3)
       return out13; 
    else
       return out1; 
else 
  return out0;  
 

So the keys are

test1  test2  test3  returned value  mask name
0      0      0      0               mask1
1      1      1      123             mask2
1      0      1      13              mask3
1      1      0      12              mask4
1      0      0      1               mask5

I hope this question is clear and I'm very happy to improve it. I guess the solution should be a general SIMD strategy that I can test.

I have tried two solutions that return the correct output but are slower:

  1. Flatten the IF/ELSE cascade by removing the nesting. This creates more jump instructions and therefore performs worse.
  2. Go branch free with IfThenElse function that basically creates AND/OR masks:
Vec result = [test1, test2, test3]
output = IfThenElse( Xor( mask1, result), 0,   output)
output = IfThenElse( Xor( mask2, result), 123, output)
output = IfThenElse( Xor( mask3, result), 13,  output)
output = IfThenElse( Xor( mask4, result), 12,  output)
output = IfThenElse( Xor( mask5, result), 1,   output) 

return output;

I took a bit of a shortcut here to keep the question concise, but the idea of the Xor operator is that it yields a mask that is true if and only if maskX is block-wise equal to result. Therefore, output is updated only when maskX is block-wise equal to result. The result is correct, but the runtime cost is higher.

Soonts On

I don’t know how to solve this with Highway.

If you’re sure these branches are unpredictable, the code is on the performance-critical path, and you have measured that the branches are the issue, here are possible solutions in C++ without Highway.

The following method uses blend instructions, and requires at least SSE 4.1:

#include <smmintrin.h> // SSE 4.1 intrinsics, needed for _mm_blendv_ps

// returns ( cond ? b : a ), without branches
inline __m128 select( __m128 a, __m128 b, bool cond )
{
    __m128i i = _mm_cvtsi32_si128( cond ? -1 : 0 );
    i = _mm_shuffle_epi32( i, _MM_SHUFFLE( 0, 0, 0, 0 ) );
    return _mm_blendv_ps( a, b, _mm_castsi128_ps( i ) );
}

__m128 select5_blend( __m128 out0, __m128 out1, __m128 out12, __m128 out13,
    __m128 out123, bool test1, bool test2, bool test3 )
{
    __m128 t0 = select( out13, out123, test2 );
    __m128 t1 = select( out1, out12, test2 );
    __m128 t = select( t1, t0, test3 );
    return select( out0, t, test1 );
}
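
Since the question mentions that test1, test2 and test3 could also be per-lane masks, here is a minimal portable sketch of the same four-select tree with one boolean per lane, checked against the nested IF/ELSE for all 8 combinations. It deliberately uses plain std::array instead of intrinsics so it compiles anywhere; Vec4, Mask4, select_lanes, select5_lanes, nested_reference and select_tree_matches_reference are hypothetical names made up for this illustration, not part of SSE, NEON or Highway.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical stand-ins for a 4-lane float register and a per-lane mask.
using Vec4 = std::array<float, 4>;
using Mask4 = std::array<bool, 4>;

// Per-lane ( mask ? b : a ), same argument order as select() above.
inline Vec4 select_lanes(const Vec4& a, const Vec4& b, const Mask4& mask) {
    Vec4 r{};
    for (std::size_t i = 0; i < 4; ++i) r[i] = mask[i] ? b[i] : a[i];
    return r;
}

// Same four-select tree as select5_blend, but every test is a per-lane mask.
inline Vec4 select5_lanes(const Vec4& out0, const Vec4& out1, const Vec4& out12,
                          const Vec4& out13, const Vec4& out123,
                          const Mask4& test1, const Mask4& test2, const Mask4& test3) {
    const Vec4 t0 = select_lanes(out13, out123, test2); // candidates when test3 is set
    const Vec4 t1 = select_lanes(out1, out12, test2);   // candidates when test3 is clear
    const Vec4 t = select_lanes(t1, t0, test3);
    return select_lanes(out0, t, test1);                // test1 clear always yields out0
}

// Reference: the nested IF/ELSE from the question, for a single lane.
inline float nested_reference(float out0, float out1, float out12, float out13,
                              float out123, bool test1, bool test2, bool test3) {
    if (test1) {
        if (test2) return test3 ? out123 : out12;
        return test3 ? out13 : out1;
    }
    return out0;
}

// Compares the select tree with the reference, lane by lane. Each lane gets a
// different (test1, test2, test3) combination so mixed masks are exercised too.
inline bool select_tree_matches_reference() {
    const Vec4 out0{0.f, 0.f, 0.f, 0.f};
    const Vec4 out1{1.f, 1.f, 1.f, 1.f};
    const Vec4 out12{12.f, 12.f, 12.f, 12.f};
    const Vec4 out13{13.f, 13.f, 13.f, 13.f};
    const Vec4 out123{123.f, 123.f, 123.f, 123.f};
    for (unsigned c = 0; c < 8; ++c) {
        Mask4 m1{}, m2{}, m3{};
        for (std::size_t i = 0; i < 4; ++i) {
            const unsigned k = (c + static_cast<unsigned>(i)) & 7u;
            m1[i] = (k & 4u) != 0;
            m2[i] = (k & 2u) != 0;
            m3[i] = (k & 1u) != 0;
        }
        const Vec4 got = select5_lanes(out0, out1, out12, out13, out123, m1, m2, m3);
        for (std::size_t i = 0; i < 4; ++i) {
            const float want = nested_reference(out0[i], out1[i], out12[i], out13[i],
                                                out123[i], m1[i], m2[i], m3[i]);
            if (got[i] != want) return false;
        }
    }
    return true;
}
```

If the masks already live in vector registers, each select_lanes call would presumably map to a single blend (or Highway's IfThenElse), so the broadcast step inside select() would not be needed; that part is an expectation, not something measured here.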

NEON version:

#include <arm_neon.h>

// returns ( cond ? b : a ), without branches
inline float32x4_t select( float32x4_t a, float32x4_t b, bool cond )
{
    uint32_t mask = cond ? ~0u : 0u;
    uint32x4_t cv = vdupq_n_u32( mask );
    return vbslq_f32( cv, b, a ); // mask comes first: takes b where the mask bits are set, a elsewhere
}

Here’s another method, which uses stores followed by loads. It’s easily portable to Highway, but I would expect it to be slower because it loads from memory immediately after the data is stored there.

__m128 select5_mem( __m128 out0, __m128 out1, __m128 out12, __m128 out13,
    __m128 out123, bool test1, bool test2, bool test3 )
{
    __m128 arr[ 5 ];
    arr[ 0 ] = out1;
    arr[ 1 ] = out13;
    arr[ 2 ] = out12;
    arr[ 3 ] = out123;
    arr[ 4 ] = out0;

    size_t idx = test3 ? 1 : 0;
    idx |= test2 ? 2 : 0;
    idx = test1 ? idx : 4;
    return arr[ idx ];
}
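
To double-check the index arithmetic in select5_mem, here is a small portable sketch that isolates the index computation and maps it to the labels used in the question (0, 1, 12, 13, 123). select5_index and select5_label are hypothetical helper names introduced only for this check.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Same index computation as select5_mem: bit 0 = test3, bit 1 = test2,
// and slot 4 is reserved for the "test1 is false" case.
inline std::size_t select5_index(bool test1, bool test2, bool test3) {
    std::size_t idx = test3 ? 1 : 0;
    idx |= test2 ? 2 : 0;
    return test1 ? idx : 4;
}

// Labels in the same slot order as arr[] in select5_mem:
// out1, out13, out12, out123, out0.
inline int select5_label(bool test1, bool test2, bool test3) {
    static constexpr std::array<int, 5> labels{1, 13, 12, 123, 0};
    return labels[select5_index(test1, test2, test3)];
}
```

Note that all four combinations with test1 false land in slot 4, which matches the outer else branch of the original code.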