Dot-product groups of 4 bytes against 4 small constants, over an array of bytes (efficiently using SIMD)?


I have a peculiar requirement that needs to be fulfilled efficiently. (SIMD, perhaps?)

src is an array of bytes. Every group of 4 bytes in the array needs to be processed as follows:

  • Multiply low nibble of src[0] by a constant number A.
  • Multiply low nibble of src[1] by a constant number B.
  • Multiply low nibble of src[2] by a constant number C.
  • Multiply low nibble of src[3] by a constant number D.

Sum the four parts above to give result.

Move on to the next set of 4 bytes and recompute result (rinse and repeat until the end of the byte array).

result is guaranteed to be small (it even fits in a byte) because all the numbers involved are very small. However, the data type of result is flexible if a different type enables a more efficient algorithm.
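Assuming A, B, C and D are non-negative, the bound follows because every masked nibble is at most 15:

$$
0 \le \text{result} = A\,(s_0 \mathbin{\&} 15) + B\,(s_1 \mathbin{\&} 15) + C\,(s_2 \mathbin{\&} 15) + D\,(s_3 \mathbin{\&} 15) \le 15\,(A + B + C + D)
$$

so a uint8_t result is safe whenever A + B + C + D ≤ 17 (15 × 17 = 255).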

Any suggestions, tips, or tricks to go faster than the following pseudo-code?

for (int i=0; i< length; i+=4)
{
  result = (src[i] & 0x0f) * A + (src[i+1] & 0x0f) * B + (src[i+2] & 0x0f) * C + (src[i+3] & 0x0f) * D;
}
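For reference, here is a compilable scalar version of the loop above. It is a sketch, not the asker's code: since the pseudo-code discards result on every iteration, this version stores each result into an output array instead. The names dotGroupsScalar and dst are hypothetical, and it assumes length is a multiple of 4.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

// Scalar baseline: dst[ k ] receives the result for the k-th group of 4 bytes.
// Assumes length is a multiple of 4 and dst has room for length / 4 results.
void dotGroupsScalar( const uint8_t* src, size_t length,
                      uint8_t A, uint8_t B, uint8_t C, uint8_t D, uint32_t* dst )
{
    assert( length % 4 == 0 );
    for( size_t i = 0; i < length; i += 4 )
    {
        // Mask the low nibble of each byte and weight it by its constant
        dst[ i / 4 ] = ( src[ i ] & 0x0Fu ) * A + ( src[ i + 1 ] & 0x0Fu ) * B
                     + ( src[ i + 2 ] & 0x0Fu ) * C + ( src[ i + 3 ] & 0x0Fu ) * D;
    }
}
```

For example, with A, B, C, D = 1, 2, 3, 4 and the bytes 0x13 0x25 0xF0 0x07, the nibbles are 3, 5, 0, 7 and the result is 3·1 + 5·2 + 0·3 + 7·4 = 41. A baseline like this is also handy for checking a SIMD version.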

BTW, result then forms an index into a higher-order array.

This particular loop is so crucial that the implementation language is no barrier. I can choose from C#, C, or MASM64.


Answer by Soonts

Here's an example of how to do that efficiently with SSE intrinsics.

#include <stdint.h>
#include <emmintrin.h>  // SSE 2
#include <tmmintrin.h>  // SSSE 3
#include <smmintrin.h>  // SSE 4.1

// Vector constants for dot4Sse function
struct ConstantVectorsSse
{
    __m128i abcd;
    __m128i lowNibbleMask;
    __m128i zero;
};

// Pack 4 bytes into a single uint32_t value
uint32_t packBytes( uint32_t a, uint32_t b, uint32_t c, uint32_t d )
{
    b <<= 8;
    c <<= 16;
    d <<= 24;
    return a | b | c | d;
}

// Initialize vector constants for dot4Sse function
struct ConstantVectorsSse makeConstantsSse( uint8_t a, uint8_t b, uint8_t c, uint8_t d )
{
    struct ConstantVectorsSse cv;
    cv.abcd = _mm_set1_epi32( (int)packBytes( a, b, c, d ) );
    cv.lowNibbleMask = _mm_set1_epi8( 0x0F );
    cv.zero = _mm_setzero_si128();
    return cv;
}

// Dot products of 4 groups of 4 bytes in memory against 4 small constants
// Returns a vector of 4 int32 lanes
__m128i dot4Sse( const uint8_t* rsi, const struct ConstantVectorsSse* cv )
{
    // Load 16 bytes, and mask away higher 4 bits in each byte
    __m128i v = _mm_loadu_si128( ( const __m128i* )rsi );
    v = _mm_and_si128( cv->lowNibbleMask, v );

    // Compute products, add pairwise
    v = _mm_maddubs_epi16( cv->abcd, v );

    // Final reduction step, add adjacent pairs of uint16_t lanes
    __m128i high = _mm_srli_epi32( v, 16 );
    __m128i low = _mm_blend_epi16( v, cv->zero, 0xAA );  // 0xAA == 0b10101010: zero the odd uint16_t lanes
    return _mm_add_epi32( high, low );
}

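The code uses the SSSE3 pmaddubsw instruction (_mm_maddubs_epi16) for the multiplication and the first step of the reduction, then adds the even and odd uint16_t lanes of the vector. The blend in that last step (pblendw, _mm_blend_epi16) requires SSE4.1.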
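To make the data flow concrete, here is a plain-C model of what dot4Sse computes for one 16-byte block. It emulates the documented behavior of pmaddubsw (bytes of the first operand are unsigned, bytes of the second are signed, and each pair of adjacent products is summed with signed saturation); maddubsModel and dot4Model are hypothetical names used only for illustration. With unsigned constants and nibbles of at most 15, every pair sum is at most 2 × 255 × 15 = 7650, so saturation never triggers and the unsigned lane reduction is exact.

```c
#include <assert.h>
#include <stdint.h>

// Scalar model of _mm_maddubs_epi16( a, b ): a[] is read as unsigned bytes, b[] as signed
// bytes; adjacent products are added and the sum is saturated to int16_t.
static void maddubsModel( const uint8_t a[ 16 ], const int8_t b[ 16 ], int16_t dst[ 8 ] )
{
    for( int i = 0; i < 8; i++ )
    {
        int32_t s = a[ 2 * i ] * b[ 2 * i ] + a[ 2 * i + 1 ] * b[ 2 * i + 1 ];
        if( s > INT16_MAX ) s = INT16_MAX;
        if( s < INT16_MIN ) s = INT16_MIN;
        dst[ i ] = (int16_t)s;
    }
}

// Scalar model of dot4Sse: mask nibbles, multiply-add pairs, then add even/odd 16-bit lanes.
static void dot4Model( const uint8_t src[ 16 ], uint8_t A, uint8_t B, uint8_t C, uint8_t D,
                       int32_t out[ 4 ] )
{
    const uint8_t abcd[ 4 ] = { A, B, C, D };
    uint8_t a[ 16 ];
    int8_t b[ 16 ];
    for( int i = 0; i < 16; i++ )
    {
        a[ i ] = abcd[ i % 4 ];               // same byte layout as _mm_set1_epi32( packBytes( ... ) )
        b[ i ] = (int8_t)( src[ i ] & 0x0F ); // 0..15, so reading it as signed is harmless
    }
    int16_t pairs[ 8 ];
    maddubsModel( a, b, pairs );
    for( int i = 0; i < 4; i++ )
    {
        // Pair sums are non-negative here, so reading them as uint16_t (as _mm_srli_epi32 does) is exact
        assert( pairs[ 2 * i ] >= 0 && pairs[ 2 * i + 1 ] >= 0 );
        out[ i ] = (int32_t)(uint16_t)pairs[ 2 * i ] + (int32_t)(uint16_t)pairs[ 2 * i + 1 ];
    }
}
```

Each out[ i ] equals the scalar formula for the i-th 4-byte group of the block.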

The above code assumes your A, B, C, D numbers are unsigned bytes. If they are signed, you'll need to swap the arguments of the _mm_maddubs_epi16 intrinsic (its first operand is treated as unsigned bytes, its second as signed bytes) and use different code for the second reduction step: t = _mm_slli_epi32( v, 16 ), then t = _mm_add_epi16( t, v ), then _mm_srai_epi32( t, 16 ).
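A plain-C model of that three-instruction sequence for a single 32-bit lane shows why it works (reduceSignedLane is a hypothetical name for illustration): the left shift moves the low int16_t sum into the high half, the 16-bit add leaves the wrapped total of both halves there, and the arithmetic right shift sign-extends it into a full int32_t.

```c
#include <assert.h>
#include <stdint.h>

// Scalar model of the signed reduction for one 32-bit lane holding two int16_t pair sums:
//   t = _mm_slli_epi32( v, 16 );  // low half moves up, new low half is 0
//   t = _mm_add_epi16( t, v );    // high half = low + high (mod 2^16), low half = low
//   r = _mm_srai_epi32( t, 16 );  // arithmetic shift sign-extends the high half
static int32_t reduceSignedLane( uint32_t lane )
{
    uint16_t lo = (uint16_t)( lane & 0xFFFFu );
    uint16_t hi = (uint16_t)( lane >> 16 );
    uint16_t sum = (uint16_t)( lo + hi );  // wrap-around 16-bit add, like _mm_add_epi16
    // Explicit sign extension, equivalent to the arithmetic right shift by 16
    return sum >= 0x8000u ? (int32_t)sum - 0x10000 : (int32_t)sum;
}
```

For instance, a lane whose low half is -30 (0xFFE2) and high half is 100 (0x0064) reduces to 70; the unsigned version in dot4Sse would instead produce 65506 + 100 for the same lane.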

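If you have AVX2, the upgrade is trivial: replace __m128i with __m256i, _mm_something with _mm256_something, and the _si128 suffixes with _si256 (e.g. _mm256_loadu_si256). Each call then processes 8 groups of 4 bytes instead of 4.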

If the length of your input is not necessarily a multiple of 4 groups (16 bytes), you'll need special handling for the final incomplete batch. Without the AVX2 _mm_maskload_epi32 instruction, here's one possible way to load an incomplete vector of 4-byte groups:

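#include <assert.h>
#include <stdint.h>
#include <string.h>
#include <smmintrin.h>  // SSE 4.1, for _mm_insert_epi32

// Load 1, 2 or 3 complete groups of 4 bytes; the remaining lanes are zeroed.
// memcpy avoids unaligned / strict-aliasing pointer casts and typically compiles to a plain mov.
// _mm_cvtsi64_si128 requires a 64-bit target.
__m128i loadPartial( const uint8_t* rsi, size_t rem )
{
    assert( rem != 0 && rem < 4 );
    int32_t i32;
    int64_t i64;
    __m128i v = _mm_setzero_si128();
    switch( rem )
    {
    case 1:
        memcpy( &i32, rsi, 4 );
        v = _mm_cvtsi32_si128( i32 );
        break;
    case 2:
        memcpy( &i64, rsi, 8 );
        v = _mm_cvtsi64_si128( i64 );
        break;
    case 3:
        memcpy( &i64, rsi, 8 );
        memcpy( &i32, rsi + 8, 4 );
        v = _mm_cvtsi64_si128( i64 );
        v = _mm_insert_epi32( v, i32, 2 );
        break;
    }
    return v;
}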
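A simpler, though different, way to handle the tail is to copy the leftover groups into a zero-filled 16-byte buffer and run the regular full-width load on that buffer; loadPartial avoids this extra round trip through memory. The sketch below assumes, as in the question, that the byte length is a multiple of 4; padTail and splitLength are hypothetical helpers. Lanes beyond the rem valid groups compute 0 and must be ignored.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

// Hypothetical helper: copy the last 1..3 complete 4-byte groups into a zero-filled
// 16-byte buffer, so the tail can go through the same full-width code path.
static void padTail( const uint8_t* rsi, size_t rem, uint8_t buf[ 16 ] )
{
    assert( rem != 0 && rem < 4 );
    memset( buf, 0, 16 );
    memcpy( buf, rsi, rem * 4 );
}

// Hypothetical helper: split the input into full 16-byte batches and leftover groups.
static void splitLength( size_t length, size_t* batches, size_t* rem )
{
    size_t groups = length / 4;  // assumes length is a multiple of 4
    *batches = groups / 4;
    *rem = groups % 4;
}
```

For a 36-byte input there are 9 groups: 2 full batches for the vector loop and 1 leftover group for the tail.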

P.S. Since you're then going to use these integers as indices, note that extracting them from an SSE vector with the _mm_extract_epi32 intrinsic (SSE4.1 pextrd) only costs a few cycles of latency.