Setting/getting 1-bits of __m256i vector from integer array of bit positions


Setting bits:

Given an array int inds[N], where each inds[i] is a bit position in the range [0, 255] (all inds[i] are sorted and unique), I need to set the corresponding bits of an __m256i to 1.

Is there a better way than what I do below:

alignas(32) uint64_t buf[4] = {0};

for (int i = 0; i < N; ++i) {
    int ind = inds[i];
    buf[ind / 64] |= 1ull << (ind % 64);
}
auto r = _mm256_load_si256((__m256i*)buf);

Getting bits:

In the opposite operation, I need to compute the product of double values at the positions of the set bits. I.e., given double const sizes[256], compute the product of those elements whose positions are set in an __m256i mask.

inline
double size (__m256i r, double const sizes[256])
{
    alignas(32) uint64_t buf[4];  // _mm256_store_si256 requires 32-byte alignment
    _mm256_store_si256((__m256i*)buf, r);

    double s[4] = {1.0, 1.0, 1.0, 1.0};

    // __builtin_ctzll(x) gives the position of the lowest set bit
    // and x &= x - 1 clears that bit

    for (; buf[0] != 0; buf[0] &= buf[0] - 1)
        s[0] *= sizes[__builtin_ctzll(buf[0]) + 0 * 64];

    for (; buf[1] != 0; buf[1] &= buf[1] - 1)
        s[1] *= sizes[__builtin_ctzll(buf[1]) + 1 * 64];

    for (; buf[2] != 0; buf[2] &= buf[2] - 1)
        s[2] *= sizes[__builtin_ctzll(buf[2]) + 2 * 64];

    for (; buf[3] != 0; buf[3] &= buf[3] - 1)
        s[3] *= sizes[__builtin_ctzll(buf[3]) + 3 * 64];

    return s[0] * s[1] * s[2] * s[3];
}

Same question: better way to do it?

Answer by harold:

It's possible to do this with AVX512, and it is more efficient than the scalar approach in some cases, depending on N.

There is something else though: the scalar approach has a problem that can be fixed, namely a loop-carried dependency through memory. For example, GCC compiles the code like this (relevant part extracted):

.L3:
    movzx   eax, BYTE PTR [rdi]
    mov     rdx, r8
    add     rdi, 1
    mov     rcx, rax
    shr     rax, 6
    sal     rdx, cl
    or      QWORD PTR [rsp-32+rax*8], rdx
    cmp     rsi, rdi
    jne     .L3

That or loads and stores the same memory location in (most) successive loop iterations. This can be avoided by writing separate loops for each chunk of the result:

#include <immintrin.h>
#include <cstddef>
#include <cstdint>

__m256i set_indexed_bits2(uint8_t* indexes, size_t N)
{
    alignas(32) uint64_t buf[4] = { 0 };
    if (N < 256)
        indexes[N] = 255;   // sentinel: stops the loops below (indexes must have room for N + 1 bytes)
    size_t i = 0;
    while (indexes[i] < 64)
        buf[0] |= 1ull << indexes[i++];
    while (indexes[i] < 128)
        buf[1] |= 1ull << (indexes[i++] & 63);   // & 63 keeps the shift count well-defined; it costs nothing on x86
    while (indexes[i] < 192)
        buf[2] |= 1ull << (indexes[i++] & 63);
    while (i < N)
        buf[3] |= 1ull << (indexes[i++] & 63);
    return _mm256_load_si256((__m256i*)buf);
}
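
Note (my addition, not part of the answer): because of the sentinel write, the array passed to set_indexed_bits2 must be writable and have room for at least N + 1 bytes. A hypothetical wrapper that copies the positions into a padded scratch buffer first:

#include <immintrin.h>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Hypothetical caller (names are mine): copies the bit positions into a
// scratch buffer with one spare byte so set_indexed_bits2 can write its sentinel.
__m256i bits_from_positions(const uint8_t* positions, size_t n)
{
    uint8_t scratch[257];                  // up to 256 positions + 1 sentinel byte
    std::memcpy(scratch, positions, n);
    return set_indexed_bits2(scratch, n);
}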

At the source level it looks like there is still a dependency through memory, but when it's written this way (with a constant array index) compilers are likely to apply an optimization where they temporarily keep buf[0] and so on in a register. For example, here's an excerpt of what GCC made of that (fairly representative of what other compilers do too):

.L15:
    add     rax, 1
    mov     r11, rdi
    sal     r11, cl
    movzx   ecx, BYTE PTR [rdx+rax]
    or      rsi, r11
    cmp     cl, -65
    jbe     .L15
    mov     QWORD PTR [rsp-16], rsi

Much better (though GCC missed the opportunity to use bts with a register destination here, which is efficient, unlike bts with a memory destination). In my tests this was more than twice as fast, but that will depend on N and other factors.
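
If you'd rather not rely on the compiler doing that promotion, a small variation (my sketch, not from the answer) accumulates each chunk in an explicit local variable and stores it once:

#include <immintrin.h>
#include <cstddef>
#include <cstdint>

// Variation of set_indexed_bits2 with explicit per-chunk accumulators, so the
// OR chain stays in a register by construction rather than by compiler heuristics.
__m256i set_indexed_bits2b(uint8_t* indexes, size_t N)
{
    alignas(32) uint64_t buf[4];
    if (N < 256)
        indexes[N] = 255;                  // sentinel, as in set_indexed_bits2
    size_t i = 0;
    uint64_t acc = 0;
    while (indexes[i] < 64)
        acc |= 1ull << indexes[i++];
    buf[0] = acc;
    acc = 0;
    while (indexes[i] < 128)
        acc |= 1ull << (indexes[i++] & 63);
    buf[1] = acc;
    acc = 0;
    while (indexes[i] < 192)
        acc |= 1ull << (indexes[i++] & 63);
    buf[2] = acc;
    acc = 0;
    while (i < N)
        acc |= 1ull << (indexes[i++] & 63);
    buf[3] = acc;
    return _mm256_load_si256((const __m256i*)buf);
}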

And here are the AVX512 hacks for good measure. On my PC (Rocket Lake) this has higher throughput than the improved scalar code for some N (I did not test latency); with Peter's suggestions the crossover is now around N = 16 or more, which is not bad. Conversion to AVX2 seems possible, but that would raise the threshold where it begins to be worth it.

#include <immintrin.h>
#include <algorithm>   // std::max
#include <bit>         // std::countr_one
#include <cstdint>

__m512i indexes_to_bits64(__m512i indexes, __mmask64 valids)
{
    // make valid bytes in the range 0..63, make invalid bytes out-of-range
    indexes = _mm512_and_epi64(indexes, _mm512_set1_epi8(63));
    indexes = _mm512_mask_blend_epi8(valids, _mm512_set1_epi8(-1), indexes);
    __m512i one = _mm512_set1_epi64(1);
    __mmask64 m = 0x0101010101010101;
    __m512i b0 = _mm512_sllv_epi64(one, _mm512_cvtepu8_epi64(_mm512_castsi512_si128(indexes)));
    __m512i b1 = _mm512_sllv_epi64(one, _mm512_maskz_permutexvar_epi8(m, _mm512_setr_epi64(8, 9, 10, 11, 12, 13, 14, 15), indexes));
    __m512i b2 = _mm512_sllv_epi64(one, _mm512_maskz_permutexvar_epi8(m, _mm512_setr_epi64(16, 17, 18, 19, 20, 21, 22, 23), indexes));
    __m512i b3 = _mm512_sllv_epi64(one, _mm512_maskz_permutexvar_epi8(m, _mm512_setr_epi64(24, 25, 26, 27, 28, 29, 30, 31), indexes));
    indexes = _mm512_shuffle_i64x2(indexes, indexes, _MM_SHUFFLE(1, 0, 3, 2));
    __m512i b4 = _mm512_sllv_epi64(one, _mm512_cvtepu8_epi64(_mm512_castsi512_si128(indexes)));
    __m512i b5 = _mm512_sllv_epi64(one, _mm512_maskz_permutexvar_epi8(m, _mm512_setr_epi64(8, 9, 10, 11, 12, 13, 14, 15), indexes));
    __m512i b6 = _mm512_sllv_epi64(one, _mm512_maskz_permutexvar_epi8(m, _mm512_setr_epi64(16, 17, 18, 19, 20, 21, 22, 23), indexes));
    __m512i b7 = _mm512_sllv_epi64(one, _mm512_maskz_permutexvar_epi8(m, _mm512_setr_epi64(24, 25, 26, 27, 28, 29, 30, 31), indexes));
    __m512i b012 = _mm512_ternarylogic_epi64(b0, b1, b2, 0xFE);
    __m512i b345 = _mm512_ternarylogic_epi64(b3, b4, b5, 0xFE);
    __m512i b67 = _mm512_or_epi64(b6, b7);
    return _mm512_ternarylogic_epi64(b012, b345, b67, 0xFE);
}
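
As a side note (my addition): the 0xFE immediate used with _mm512_ternarylogic_epi64 above is an 8-entry truth table, indexed per bit position by the three input bits, and 0xFE is 1 for every index except 0, i.e. it computes a three-input OR. A scalar model of that lookup:

#include <cassert>
#include <cstdint>

// Scalar model of one bit of vpternlog: the immediate is a truth table
// indexed by the three corresponding input bits.
inline int ternlog_bit(int a, int b, int c, uint8_t imm)
{
    return (imm >> ((a << 2) | (b << 1) | c)) & 1;
}

int main()
{
    for (int a = 0; a < 2; ++a)
        for (int b = 0; b < 2; ++b)
            for (int c = 0; c < 2; ++c)
                assert(ternlog_bit(a, b, c, 0xFE) == (a | b | c));  // 0xFE == 3-way OR
}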

__m256i set_indexed_bits_avx512(uint8_t* indexes, int N)
{
    // load values 0..63 into one chunk,
    // 64..127 into the next chunk,
    // 128..191 into the third chunk,
    // 192..255 into the last chunk;
    // the split happens automatically based on bits 7 and 6 (the input is sorted)
    __m512i chunk0 = _mm512_loadu_epi8(indexes);
    __mmask64 valids0 = _mm512_cmple_epu8_mask(chunk0, _mm512_set1_epi8(63));
    int chunk0_count = std::countr_one(valids0);
    valids0 = _bzhi_u64(valids0, N);
    __m512i chunk1 = _mm512_loadu_epi8(indexes + chunk0_count);
    __mmask64 valids1 = _mm512_cmple_epu8_mask(chunk1, _mm512_set1_epi8(127));
    int chunk1_count = std::countr_one(valids1);
    valids1 = _bzhi_u64(valids1, std::max(0, N - chunk0_count));
    __m512i chunk2 = _mm512_loadu_epi8(indexes + chunk0_count + chunk1_count);
    __mmask64 valids2 = _mm512_cmple_epu8_mask(chunk2, _mm512_set1_epi8(191));
    int chunk2_count = std::countr_one(valids2);
    valids2 = _bzhi_u64(valids2, std::max(0, N - chunk0_count - chunk1_count));
    __m512i chunk3 = _mm512_loadu_epi8(indexes + chunk0_count + chunk1_count + chunk2_count);
    __mmask64 valids3 = _bzhi_u64(-1ULL, std::max(0, N - chunk0_count - chunk1_count - chunk2_count));
    // 1 << bottom 6 bits
    chunk0 = indexes_to_bits64(chunk0, valids0);
    chunk1 = indexes_to_bits64(chunk1, valids1);
    chunk2 = indexes_to_bits64(chunk2, valids2);
    chunk3 = indexes_to_bits64(chunk3, valids3);
    // interleave and reduce horizontally
    __m512i chunk01 = _mm512_or_epi64(
        _mm512_unpacklo_epi64(chunk0, chunk1),
        _mm512_unpackhi_epi64(chunk0, chunk1));
    __m512i chunk23 = _mm512_or_epi64(
        _mm512_unpacklo_epi64(chunk2, chunk3),
        _mm512_unpackhi_epi64(chunk2, chunk3));
    __m256i chunk01_2 = _mm256_or_si256(_mm512_castsi512_si256(chunk01), _mm512_extracti64x4_epi64(chunk01, 1));
    __m256i chunk23_2 = _mm256_or_si256(_mm512_castsi512_si256(chunk23), _mm512_extracti64x4_epi64(chunk23, 1));
    __m128i chunk01_3 = _mm_or_si128(_mm256_castsi256_si128(chunk01_2), _mm256_extracti128_si256(chunk01_2, 1));
    __m128i chunk23_3 = _mm_or_si128(_mm256_castsi256_si128(chunk23_2), _mm256_extracti128_si256(chunk23_2, 1));
    return _mm256_inserti128_si256(_mm256_castsi128_si256(chunk01_3), chunk23_3, 1);
}
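
For completeness, a small self-contained check (my own sketch, not part of the answer) that compares set_indexed_bits_avx512 against the question's scalar construction; it needs a CPU and compiler flags with AVX-512BW/VBMI support (e.g. -march=icelake-client or newer):

#include <immintrin.h>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <random>
#include <vector>

// Reference: the question's straightforward scalar construction.
static __m256i reference_bits(const uint8_t* inds, size_t n)
{
    alignas(32) uint64_t buf[4] = {0};
    for (size_t i = 0; i < n; ++i)
        buf[inds[i] / 64] |= 1ull << (inds[i] % 64);
    return _mm256_load_si256((const __m256i*)buf);
}

int main()
{
    std::mt19937 rng(1);
    for (int trial = 0; trial < 10000; ++trial) {
        // draw a random sorted, unique subset of [0, 255]
        std::vector<uint8_t> inds;
        for (int v = 0; v < 256; ++v)
            if (rng() & 1)
                inds.push_back((uint8_t)v);
        size_t n = inds.size();
        // the AVX-512 routine always loads whole 64-byte chunks, so hand it a
        // fixed 256-byte buffer; anything past n is masked off via N
        alignas(64) uint8_t padded[256] = {0};
        if (n != 0)
            std::memcpy(padded, inds.data(), n);
        alignas(32) uint64_t a[4], b[4];
        _mm256_store_si256((__m256i*)a, reference_bits(padded, n));
        _mm256_store_si256((__m256i*)b, set_indexed_bits_avx512(padded, (int)n));
        if (std::memcmp(a, b, sizeof a) != 0) {
            std::printf("mismatch (n = %zu)\n", n);
            return 1;
        }
    }
    std::printf("all trials match\n");
}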