Find all subsets of a mask which are divisible by a certain number


All subsets of the 4-bit number 1101 are 0000, 0001, 0100, 0101, 1000, 1001, 1100, 1101. All subsets of this mask which are divisible by 2 are 0000, 0100, 1000, 1100.

Given a 64-bit mask M and a 64-bit integer P, how do I iterate over all subsets of M which are divisible by P?

To iterate over subsets of a bit mask, I can do

uint64_t superset = ...;
uint64_t subset = 0;
do {
    print(subset);
    subset = (subset - superset) & superset;  /* next subset of superset, in increasing order */
} while (subset != 0);

If M is ~0 I can just start with 0 and keep adding P to iterate over all multiples of P. If P is a power of two I can just do M &= ~(P - 1) to chop off bits which are never going to be set.
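
For instance, the two special cases might look something like this, in the same snippet style as above (assuming P is nonzero):

/* Case 1: M is ~0, so every 64-bit multiple of P is a subset of M. */
uint64_t multiple = 0;
for (;;) {
    print(multiple);
    if (multiple > UINT64_MAX - P)  /* the next addition would wrap */
        break;
    multiple += P;
}

/* Case 2: P is a power of two, so drop the bits of M below P; every
 * subset of the reduced mask is then automatically divisible by P. */
M &= ~(P - 1);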

But if I have none of the constraints above, do I have a better shot than naively checking each and every subset for divisibility by P? On average, this naive algorithm takes O(P) operations to find the next subset divisible by P. Can I do better than O(P)?


Answer by Davislor

A Parallel Algorithm

There are inputs for which it is vastly more efficient to check the multiples of the factor than the subsets of the mask, and inputs where it’s the other way around. For example, when M is 0xFFFFFFFFFFFFFFFF and P is 0x4000000000000000, checking the three multiples of P is nigh-instantaneous, but even if you could crunch and check a billion subsets of M each second, enumerating all 2⁶⁴ of them would take more than five hundred years. The optimization of finding only subsets greater than or equal to P would only cut that by a quarter.

However, there is a strong reason to enumerate and check the multiples of P instead of the subsets of M: parallelism. I want to emphasize this, because of incorrect comments on this code elsewhere: the algorithm in the OP is inherently sequential, because each value of subset depends on the previous value of subset. It cannot run until all the lower subsets have already been calculated, and it cannot be vectorized to use AVX registers or similar. You cannot load four values into an AVX2 register and run SIMD instructions on them, because you would need to calculate the first value to initialize the second element, the second to initialize the third, and all three to initialize the final one, and then you are back to computing only one value at a time. Nor can it be split between worker threads on different CPU cores, which is a separate question from vectorization. (The accepted answer can be modified to do the latter, but not the former without a total refactoring.) You cannot divide the workload into subsets 0 to 63, subsets 64 to 127, and so on, and have different threads work on each range in parallel, because you cannot start on the sixty-fourth subset until you know what the sixty-third subset is, for which you need the sixty-second, and so on.

If you take nothing else away from this, I highly recommend that you try this code out on Godbolt with full optimizations enabled, and see for yourself that it compiles to sequential code. If you’re familiar with OpenMP, try adding #pragma omp simd and #pragma omp parallel directives and see what happens. The problem isn’t with the compiler; it’s that the algorithm is inherently sequential. But seeing what real compilers do should at least convince you that compilers in the year 2023 are not able to vectorize code like this.

For reference, here is what Clang 16 does with Find:

Find:                                   # @Find
        push    r15
        push    r14
        push    r12
        push    rbx
        push    rax
        mov     rbx, rdi
        cmp     rdi, rsi
        jne     .LBB1_1
.LBB1_6:
        lea     rdi, [rip + .L.str]
        mov     rsi, rbx
        xor     eax, eax
        add     rsp, 8
        pop     rbx
        pop     r12
        pop     r14
        pop     r15
        jmp     printf@PLT                      # TAILCALL
.LBB1_1:
        mov     r14, rdx
        mov     r15, rsi
        jmp     .LBB1_2
.LBB1_5:                                #   in Loop: Header=BB1_2 Depth=1
        imul    r12, r14
        add     r15, r12
        cmp     r15, rbx
        je      .LBB1_6
.LBB1_2:                                # =>This Inner Loop Header: Depth=1
        cmp     r15, rbx
        ja      .LBB1_7
        mov     rax, r15
        xor     rax, rbx
        blsi    r12, rax
        test    r12, rbx
        je      .LBB1_5
        mov     rdi, rbx
        sub     rdi, r12
        mov     rsi, r15
        mov     rdx, r14
        call    Find
        jmp     .LBB1_5
.LBB1_7:
        add     rsp, 8
        pop     rbx
        pop     r12
        pop     r14
        pop     r15
        ret

Enumerate and Check the Multiples Instead of the Subsets

In addition to having more parallelism, this approach has other speed advantages:

  • Finding the successor, (i+1)*p from i*p (or (i+4)*p from i*p when working on a vector of four elements), can be strength-reduced to a single addition, as in the scalar sketch below.
  • Testing whether a multiple is a subset is a single and operation, whereas testing whether a subset is divisible requires a % operation, which many CPUs do not have as a native instruction and which is among the slowest ALU operations even when it is there.
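
As a minimal scalar sketch of that idea (my own helper name; assuming factor is nonzero), the multiple can be carried along by addition and tested against the mask with a single and:

#include <stdint.h>
#include <stdio.h>

/* Print every multiple of factor that is also a subset of mask. */
static void print_subset_multiples(uint64_t mask, uint64_t factor)
{
    const uint64_t end = mask / factor;   /* largest n with n*factor <= mask */
    uint64_t m = 0;                       /* m == n*factor, maintained by addition */
    for (uint64_t n = 0; n <= end; ++n, m += factor) {
        if ((m & ~mask) == 0)             /* subset test: one and, no % */
            printf("%llx ", (unsigned long long)m);
    }
}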

So, a version of this code that uses both multi-threading and SIMD for speed-up:

#include <assert.h>
#include <omp.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>


typedef uint_fast32_t word;

/* Sets each element results[i], where i <= mask/factor, to true if factor*i
 * is a subset of the mask, false otherwise.  The results array MUST have at
 * least (mask/factor + 1U) elements.  The capacity of results in elements is
 * required and checked, just in case.
 *
 * Returns a pointer to the results.
 */
static bool* check_multiples( const word mask,
                              const word factor,
                              const size_t n,
                              bool results[n] )
{
    const word end = mask/factor;
    const word complement = ~mask;
    assert(results);
    assert(n > end);

    #pragma omp parallel for simd schedule(static)
    for (word i = 0; i <= end; ++i) {
        results[i] = (factor*i & complement) == 0;
    }

    return results;
}

/* Replace these with non-constants so that the compiler actually
 * instantiates the function:
 */
/*
#define MASK 0xA0A0UL
#define FACTOR 0x50UL
#define NRESULTS (MASK/FACTOR + 1U)
 */
extern const word MASK, FACTOR;
#define NRESULTS 1024UL

int main(void)
{
    bool are_subsets[NRESULTS] = {false};
    (void)check_multiples(MASK, FACTOR, NRESULTS, are_subsets);

    for (word i = 0; i < NRESULTS; ++i) {
        if (are_subsets[i]) {
            const unsigned long long multiple = (unsigned long long)FACTOR*i;
            printf("%llx ", multiple);
            assert((multiple & MASK) == multiple && (multiple & ~MASK) == 0U);
        }
    }

    return EXIT_SUCCESS;
}

The inner loop of check_multiples compiles, on ICX 2022, to:

.LBB1_5:                                # =>This Inner Loop Header: Depth=1
        vpmullq         ymm15, ymm1, ymm0
        vpmullq         ymm16, ymm2, ymm0
        vpmullq         ymm17, ymm3, ymm0
        vpmullq         ymm18, ymm4, ymm0
        vpmullq         ymm19, ymm5, ymm0
        vpmullq         ymm20, ymm6, ymm0
        vpmullq         ymm21, ymm7, ymm0
        vpmullq         ymm22, ymm8, ymm0
        vptestnmq       k0, ymm22, ymm9
        vptestnmq       k1, ymm21, ymm9
        kshiftlb        k1, k1, 4
        korb            k0, k0, k1
        vptestnmq       k1, ymm20, ymm9
        vptestnmq       k2, ymm19, ymm9
        kshiftlb        k2, k2, 4
        korb            k1, k1, k2
        kunpckbw        k0, k1, k0
        vptestnmq       k1, ymm18, ymm9
        vptestnmq       k2, ymm17, ymm9
        kshiftlb        k2, k2, 4
        korb            k1, k1, k2
        vptestnmq       k2, ymm16, ymm9
        vptestnmq       k3, ymm15, ymm9
        kshiftlb        k3, k3, 4
        korb            k2, k2, k3
        kunpckbw        k1, k2, k1
        kunpckwd        k1, k1, k0
        vmovdqu8        ymm15 {k1} {z}, ymm10
        vmovdqu         ymmword ptr [rbx + rsi], ymm15
        vpaddq          ymm15, ymm11, ymm7
        vpaddq          ymm16, ymm6, ymm11
        vpaddq          ymm17, ymm5, ymm11
        vpaddq          ymm18, ymm4, ymm11
        vpaddq          ymm19, ymm3, ymm11
        vpaddq          ymm20, ymm2, ymm11
        vpaddq          ymm21, ymm1, ymm11
        vpmullq         ymm21, ymm21, ymm0
        vpmullq         ymm20, ymm20, ymm0
        vpmullq         ymm19, ymm19, ymm0
        vpmullq         ymm18, ymm18, ymm0
        vpmullq         ymm17, ymm17, ymm0
        vpmullq         ymm16, ymm16, ymm0
        vpmullq         ymm15, ymm15, ymm0
        vpaddq          ymm22, ymm8, ymm11
        vpmullq         ymm22, ymm22, ymm0
        vptestnmq       k0, ymm22, ymm9
        vptestnmq       k1, ymm15, ymm9
        kshiftlb        k1, k1, 4
        korb            k0, k0, k1
        vptestnmq       k1, ymm16, ymm9
        vptestnmq       k2, ymm17, ymm9
        kshiftlb        k2, k2, 4
        korb            k1, k1, k2
        kunpckbw        k0, k1, k0
        vptestnmq       k1, ymm18, ymm9
        vptestnmq       k2, ymm19, ymm9
        kshiftlb        k2, k2, 4
        korb            k1, k1, k2
        vptestnmq       k2, ymm20, ymm9
        vptestnmq       k3, ymm21, ymm9
        kshiftlb        k3, k3, 4
        korb            k2, k2, k3
        kunpckbw        k1, k2, k1
        kunpckwd        k1, k1, k0
        vmovdqu8        ymm15 {k1} {z}, ymm10
        vmovdqu         ymmword ptr [rbx + rsi + 32], ymm15
        vpaddq          ymm15, ymm12, ymm7
        vpaddq          ymm16, ymm6, ymm12
        vpaddq          ymm17, ymm5, ymm12
        vpaddq          ymm18, ymm4, ymm12
        vpaddq          ymm19, ymm3, ymm12
        vpaddq          ymm20, ymm2, ymm12
        vpaddq          ymm21, ymm1, ymm12
        vpmullq         ymm21, ymm21, ymm0
        vpmullq         ymm20, ymm20, ymm0
        vpmullq         ymm19, ymm19, ymm0
        vpmullq         ymm18, ymm18, ymm0
        vpmullq         ymm17, ymm17, ymm0
        vpmullq         ymm16, ymm16, ymm0
        vpmullq         ymm15, ymm15, ymm0
        vpaddq          ymm22, ymm8, ymm12
        vpmullq         ymm22, ymm22, ymm0
        vptestnmq       k0, ymm22, ymm9
        vptestnmq       k1, ymm15, ymm9
        kshiftlb        k1, k1, 4
        korb            k0, k0, k1
        vptestnmq       k1, ymm16, ymm9
        vptestnmq       k2, ymm17, ymm9
        kshiftlb        k2, k2, 4
        korb            k1, k1, k2
        kunpckbw        k0, k1, k0
        vptestnmq       k1, ymm18, ymm9
        vptestnmq       k2, ymm19, ymm9
        kshiftlb        k2, k2, 4
        korb            k1, k1, k2
        vptestnmq       k2, ymm20, ymm9
        vptestnmq       k3, ymm21, ymm9
        kshiftlb        k3, k3, 4
        korb            k2, k2, k3
        kunpckbw        k1, k2, k1
        kunpckwd        k1, k1, k0
        vmovdqu8        ymm15 {k1} {z}, ymm10
        vmovdqu         ymmword ptr [rbx + rsi + 64], ymm15
        vpaddq          ymm15, ymm13, ymm7
        vpaddq          ymm16, ymm6, ymm13
        vpaddq          ymm17, ymm5, ymm13
        vpaddq          ymm18, ymm4, ymm13
        vpaddq          ymm19, ymm3, ymm13
        vpaddq          ymm20, ymm2, ymm13
        vpaddq          ymm21, ymm1, ymm13
        vpmullq         ymm21, ymm21, ymm0
        vpmullq         ymm20, ymm20, ymm0
        vpmullq         ymm19, ymm19, ymm0
        vpmullq         ymm18, ymm18, ymm0
        vpmullq         ymm17, ymm17, ymm0
        vpmullq         ymm16, ymm16, ymm0
        vpmullq         ymm15, ymm15, ymm0
        vpaddq          ymm22, ymm8, ymm13
        vpmullq         ymm22, ymm22, ymm0
        vptestnmq       k0, ymm22, ymm9
        vptestnmq       k1, ymm15, ymm9
        kshiftlb        k1, k1, 4
        korb            k0, k0, k1
        vptestnmq       k1, ymm16, ymm9
        vptestnmq       k2, ymm17, ymm9
        kshiftlb        k2, k2, 4
        korb            k1, k1, k2
        kunpckbw        k0, k1, k0
        vptestnmq       k1, ymm18, ymm9
        vptestnmq       k2, ymm19, ymm9
        kshiftlb        k2, k2, 4
        korb            k1, k1, k2
        vptestnmq       k2, ymm20, ymm9
        vptestnmq       k3, ymm21, ymm9
        kshiftlb        k3, k3, 4
        korb            k2, k2, k3
        kunpckbw        k1, k2, k1
        kunpckwd        k1, k1, k0
        vmovdqu8        ymm15 {k1} {z}, ymm10
        vmovdqu         ymmword ptr [rbx + rsi + 96], ymm15
        vpaddq          ymm8, ymm8, ymm14
        vpaddq          ymm7, ymm14, ymm7
        vpaddq          ymm6, ymm14, ymm6
        vpaddq          ymm5, ymm14, ymm5
        vpaddq          ymm4, ymm14, ymm4
        vpaddq          ymm3, ymm14, ymm3
        vpaddq          ymm2, ymm14, ymm2
        vpaddq          ymm1, ymm14, ymm1
        sub             rsi, -128
        add             rdi, -4
        jne             .LBB1_5

I encourage you to try your variations on the algorithm in this compiler, under the same settings, and see what happens. If you think it is possible to generate vectorized code over the subsets that is as good as this, try it for yourself.

A Possible Improvement

The number of candidates to check could get extremely large, but one way to limit it is to also compute the modular multiplicative inverse of the odd part of P, and use it when that yields fewer candidates.

Every value of P decomposes into 2ⁱ · Q, where Q is odd. Since Q and 2⁶⁴ are coprime, Q has a modular multiplicative inverse Q', with QQ' ≡ 1 (mod 2⁶⁴). You can find it with the extended Euclidean algorithm (but not the method I proposed here initially).
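
For illustration, here is a minimal sketch of computing Q' (my own helper name); it uses Newton/Hensel iteration, which doubles the number of correct low-order bits each step, rather than the extended Euclidean algorithm:

#include <assert.h>
#include <stdint.h>

/* Compute Q' for odd q, so that q * Q' == 1 (mod 2^64). */
static uint64_t inverse_mod_2_64(uint64_t q)
{
    assert(q & 1);               /* only odd values are invertible mod 2^64 */
    uint64_t x = q;              /* correct to 3 bits: q*q == 1 (mod 8) for odd q */
    for (int i = 0; i < 5; ++i)  /* 3 -> 6 -> 12 -> 24 -> 48 -> 96 correct bits */
        x *= 2 - q * x;
    return x;
}

A quick sanity check: for any odd q, q * inverse_mod_2_64(q) comes out to exactly 1 in 64-bit arithmetic.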

This is useful for optimizing the algorithm because, for many values of P, Q' < P. If m is a solution, m = nP for some integer n. Multiply both sides by Q': since Q'P = 2ⁱ · Q'Q ≡ 2ⁱ (mod 2⁶⁴), this gives Q'm ≡ 2ⁱ · n (mod 2⁶⁴). This means we can enumerate (with a bit of extra logic to make sure they have enough trailing zero bits) the multiples of Q' or of P, whichever gives fewer candidates. Note that, since Q' is odd, it is not necessary to check all multiples of Q'; if 2ⁱ is 4, for example, you need only check the multiples of 4·Q'.
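
As a concrete illustration of the decomposition (a hypothetical helper of my own, reusing inverse_mod_2_64 from the sketch above): for the commented-out example factor 0x50 = 2⁴ · 5 in the code above, Q is 5, and Q' · P wraps around to exactly 2⁴ modulo 2⁶⁴.

#include <assert.h>
#include <stdint.h>

/* Split p into its odd part q and power-of-two exponent i, so p == q << i. */
static void decompose(uint64_t p, unsigned *i, uint64_t *q)
{
    assert(p != 0);
    *i = 0;
    while ((p & 1) == 0) {   /* strip trailing zero bits */
        p >>= 1;
        ++*i;
    }
    *q = p;
}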