How to combine elements in Karatsuba multiplication


I'm trying to implement multiplication using the Karatsuba algorithm.
The usual bit masking and shifting approach implemented in mul_kar2() works.

In mul_kar2_au(), I'm struggling to combine the individual components into a "single number" held in a union (an array would do as well). When I combine them as in school arithmetic and cast to a 64-bit integer, it works (the line commented out with //). When I just "put them together", it doesn't. I know I have to emulate the mathematical operation, but I don't know how.
I don't want to use a 64-bit integer, as if the code ran on a 32-bit processor/microcontroller. I also want to extend it to 128-bit multiplication, which arm64 does not have in hardware, so there is no wider type to cast to.

#include <limits.h>   /* CHAR_BIT */
#include <stdint.h>

typedef union
{
    uint32_t au32[2];
    uint16_t au16[4];
} au64_u;

au64_u mul_kar2_au(au64_u aru64)
{
    uint32_t ha, hb, la, lb, z2, z1, z0;
    int m = sizeof(uint32_t)*CHAR_BIT, m_2 = m/2;
    au64_u res;
    uint32_t a = aru64.au32[1];
    uint32_t b = aru64.au32[0];

    ha = a >> m_2;
    hb = b >> m_2;
    la = (uint16_t)a;
    lb = (uint16_t)b;
    z2 = ha * hb;
    z0 = la * lb;
    z1 = (ha + la) * (hb + lb) - z2 - z0;

    //uint64_t res = ((uint64_t)z2 << (m_2 * 2)) + ((uint64_t)z1 << m_2) + (uint64_t)z0;
    aru64.au16[0] = z0;
    aru64.au16[1] = (uint16_t)z1;
    aru64.au16[2] = z1 >> m_2;
    aru64.au16[3] = z2;

    return aru64;
}

And this is the result:

m=32, m_2=16
a=1458354473 (0x56ecb929)
b=1458354473 (0x56ecb929)

y=a*b:
2126797768919107729     0x1d83e74d75844891

t=mul_kar2(a, b):
2126797768919107729     0x1d83e74d75844891

t=mul_kar2_au(a, b):
7606718021055826065     0x69907dbcef984891

As can be seen, only the last (least significant) 16 bits are correct.
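Working the sample input through by hand shows where the bits have to go. Since a = b here, ha = hb = 0x56ec and la = lb = 0xb929, so the true z1 = ha·lb + la·hb is simply 2·ha·la (values checked against the reference product printed above):

```latex
\begin{aligned}
z_2 &= h_a h_b = \texttt{0x56EC}^2 = \texttt{0x1D836990}\\
z_0 &= l_a l_b = \texttt{0xB929}^2 = \texttt{0x85EC4891}\\
z_1 &= h_a l_b + l_a h_b = 2 \cdot \texttt{0x56EC} \cdot \texttt{0xB929} = \texttt{0x7DBCEF98}\\
z_2 \cdot 2^{32} + z_1 \cdot 2^{16} + z_0
  &= \texttt{0x1D836990\,00000000} + \texttt{0x00007DBC\,EF980000} + \texttt{0x00000000\,85EC4891}\\
  &= \texttt{0x1D83E74D\,75844891}
\end{aligned}
```

The two middle 16-bit digits each receive contributions from two of the parts, and the carry out of the low word (0xEF98 + 0x85EC = 0x17584) bumps the high word from 0x1D83E74C to 0x1D83E74D. The union assignments in mul_kar2_au() instead store 0x6990, 0x7DBC, 0xEF98 and 0x4891 side by side (0x69907DBCEF984891): the upper halves of z2 and z0 are dropped, and nothing is added where the parts overlap.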

The whole program can be found here.

If this is a math question, I will move it to the appropriate site.


Answer by greybeard

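In school arithmetic, you add properly aligned parts. Using A, a and B, b as single-character designators for the high and low halves of the two factors:

z0:           abab abab
z1:      (A+a)(B+b)
z2: ABAB ABAB
summing:
    ABAB ABAB abab abab
  +      (A+a)(B+b)

(with z1 = (A+a)(B+b) - z2 - z0, as in the question). z2 and z0 can simply be placed side by side, but z1 has to be added across the two middle 16-bit digits, with the carry propagated into the top one.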

aru64.au16[1] = (uint16_t)z1;
aru64.au16[2] = z1 >> m_2;

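ignores the upper half of z0 and the lower half of z2, both of which belong in those same two 16-bit digits and have to be added to z1 there, with carries. Likewise, aru64.au16[3] = z2 stores the lower half of z2 in the top digit, where its upper half (plus any carry) belongs.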
For one way to avoid problems with the size of z1 (its true value can need 33 bits), see the end of the section on implementation in the English Wikipedia article on Karatsuba's multiplication algorithm.
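A minimal sketch of mul_kar2_au() with the three parts added digit-aligned and carries propagated, using nothing wider than uint32_t. It assumes the question's input convention (factors in au32[1] and au32[0]) and treats au32[0] as the low word of the result, which matches a uint64_t only on little-endian targets; the name mul_kar2_au_fixed is hypothetical. To keep z1 within reach of 32-bit arithmetic it uses the subtraction form z1 = z2 + z0 - (ha - la)(hb - lb), a common way to avoid forming the up-to-34-bit product (ha + la)(hb + lb), and keeps the up-to-33-bit z1 as a 32-bit word plus one extra bit:

```c
#include <stdint.h>

typedef union
{
    uint32_t au32[2];
    uint16_t au16[4];
} au64_u;

/* Karatsuba 32x32 -> 64 bit using only 32-bit arithmetic (sketch).
 * Input: factors in aru64.au32[1] and aru64.au32[0], as in the question.
 * Output: au32[0] = low word, au32[1] = high word (assumed convention). */
au64_u mul_kar2_au_fixed(au64_u aru64)
{
    uint32_t a = aru64.au32[1], b = aru64.au32[0];
    uint32_t ha = a >> 16, la = a & 0xFFFFu;
    uint32_t hb = b >> 16, lb = b & 0xFFFFu;
    uint32_t z2 = ha * hb;
    uint32_t z0 = la * lb;

    /* z1 = ha*lb + la*hb = z2 + z0 - (ha - la)*(hb - lb).
     * |ha - la| and |hb - lb| fit in 16 bits, so the product of the
     * magnitudes fits in 32 bits; its sign is tracked separately. */
    uint32_t da = ha >= la ? ha - la : la - ha;
    uint32_t db = hb >= lb ? hb - lb : lb - hb;
    int negative = (ha < la) != (hb < lb);  /* is (ha-la)*(hb-lb) < 0 ? */
    uint32_t d = da * db;

    /* z1 may need 33 bits: keep it as the pair z1_hi:z1 (z1_hi ends up 0 or 1) */
    uint32_t z1 = z2 + z0;
    uint32_t z1_hi = z1 < z0;               /* carry out of z2 + z0 */
    if (negative) {                         /* subtracting a negative value: add d */
        z1 += d;
        z1_hi += z1 < d;
    } else {                                /* subtract d, borrowing from z1_hi */
        z1_hi -= z1 < d;
        z1 -= d;
    }

    /* result = z2 * 2^32 + (z1_hi:z1) * 2^16 + z0, added digit-aligned */
    au64_u res;
    uint32_t lo = z0 + (z1 << 16);          /* low half of z1 lands in bits 16..31 */
    res.au32[0] = lo;
    /* upper half of z1, its 33rd bit, and the carry out of the low word */
    res.au32[1] = z2 + (z1 >> 16) + (z1_hi << 16) + (lo < z0);
    return res;
}
```

For the sample input this gives au32[1] = 0x1d83e74d and au32[0] = 0x75844891, i.e. the reference product. Here the sign and borrow bookkeeping replaces the extra 34-bit multiplication; for a 128-bit version the same structure applies with uint64_t words and 32-bit halves.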

Answer by greybeard

Answering a different question, actually:
How to get the full product of a 32×32 multiplication using 32-bit arithmetic exclusively?
Caveat: tried, but not seriously tested

#include <stdint.h>

/** full 64 bit product from multiplying two uint32_t factors
 *  using 32-bit arithmetic exclusively, assuming little endian storage order */
static const int HALF_BIT = 16;  // half the bits of a uint32_t (2 * CHAR_BIT with 8-bit bytes)
typedef union {
    uint32_t au32[2];
    uint16_t au16[4];
} au64_u;

/** return the product of the factors in aru64.au32 assuming little endian order */
au64_u product(au64_u aru64)
{
    au64_u result;
    uint32_t  // good thing multiplication is commutative
        a  = aru64.au32[1], b  = aru64.au32[0],
        la = (uint16_t)a,   lb = (uint16_t)b,
        ha = aru64.au16[3], hb = aru64.au16[1],
        low = la * lb, high = ha * hb,
        mid1 = ha * lb, mid = la * hb;
//  result.au32[0] = low + (mid1+mid << HALF_BIT);  // for expensive multiplication
    result.au32[0] = a * b;  // this includes the lower halves of ha*lb & la*hb
    // the product of two k-bit numbers plus up to two k-bit numbers still fits in 2k bits, so this cannot overflow
    mid1 += low >> HALF_BIT;
    // this addition needs to include all the low order bits
    //  to get the carry into the upper half right - and may produce a carry out
    mid += mid1;
    if (mid < mid1)
        high += 1<<HALF_BIT;
    result.au32[1] = high + (mid >> HALF_BIT);

    return result;
}

No recursion, no Toom-Cook.
For a not-so-different way to handle carries, see Ben Voigt's answer on how to get the upper half of a 64×64-bit product.
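The carry handling in product() carries over unchanged to the 64×64 → 128-bit case the question ultimately wants: split each uint64_t into 32-bit halves and do the same steps in 64-bit arithmetic. A sketch under that assumption; the name mul_64x64_128 and its out-parameter interface are hypothetical:

```c
#include <stdint.h>

/* Full 128-bit product of two uint64_t values using only 64-bit arithmetic,
 * mirroring product(): 32-bit halves instead of 16-bit ones.
 * *hi receives bits 64..127, *lo bits 0..63. */
void mul_64x64_128(uint64_t a, uint64_t b, uint64_t *hi, uint64_t *lo)
{
    const int HALF = 32;
    uint64_t la = (uint32_t)a, ha = a >> HALF;
    uint64_t lb = (uint32_t)b, hb = b >> HALF;
    uint64_t low = la * lb, high = ha * hb;
    uint64_t mid1 = ha * lb, mid = la * hb;

    *lo = a * b;                      /* the low word is just the product mod 2^64 */
    mid1 += low >> HALF;              /* cannot overflow: (2^32-1)^2 + (2^32-1) < 2^64 */
    mid += mid1;                      /* this addition may wrap around 2^64 ... */
    if (mid < mid1)
        high += (uint64_t)1 << HALF;  /* ... a carry worth 2^96, i.e. 2^32 in the high word */
    *hi = high + (mid >> HALF);
}
```

For example, multiplying UINT64_MAX by itself yields hi = 0xFFFFFFFFFFFFFFFE and lo = 1.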

Answer by tansy

Multiplication of two 32-bit integers into a 64-bit product, using long multiplication with four 16-bit 'digits' (two per factor).

#include <limits.h>   /* CHAR_BIT */
#include <stdint.h>

typedef union
    {
    uint32_t au32[2];
    uint16_t au16[4];
    } au64_u;

/* assumes little endian storage: au16[0..1] are the halves of au32[0],
   au16[2..3] the halves of au32[1] */
au64_u mul_2x32_long(au64_u aru64)
{
    au64_u result;
    uint32_t
        la = aru64.au16[0], ha = aru64.au16[1],
        lb = aru64.au16[2], hb = aru64.au16[3];
    // high * 2^32 + mid * 2^16 + low
    // = ha*hb * 2^32 + (ha*lb + la*hb) * 2^16 + la*lb
    uint32_t
    word_bits = sizeof(uint32_t)*CHAR_BIT/2,
    hcarry = 0, lcarry = 0,
    high = ha*hb,
    mid = ha*lb + la*hb,   // may wrap around 2^32
    low = la*lb,
    hhalf, lhalf;

    if (mid < ha*lb)       // mid wrapped: carry worth 2^48, i.e. 2^16 in the high word
        hcarry = 1;
    lhalf = low + (mid << word_bits);
    if (lhalf < low)       // carry out of the low word
        lcarry = 1;
    hhalf = high + (mid >> word_bits) + (hcarry << word_bits) + lcarry;
    result.au32[0] = lhalf;
    result.au32[1] = hhalf;

    return result;
}

It solves the multiplication problem, though it is not the Karatsuba algorithm.

It is slightly slower (~6%) than this answer and longer (28 lines vs. 25 lines), but it uses four multiplications, just as in the definition of long multiplication.
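The same long-multiplication idea generalizes to factors stored as arrays of 16-bit digits (the question notes an array would do as well as a union), which also covers 64×64 → 128 bits and wider with nothing larger than uint32_t. A minimal sketch; the function name mul_digits and the little-endian digit order (index 0 least significant) are assumptions:

```c
#include <stddef.h>
#include <stdint.h>

/* Schoolbook multiplication of little-endian arrays of 16-bit digits.
 * out must have room for na + nb digits and must not alias a or b. */
void mul_digits(const uint16_t *a, size_t na,
                const uint16_t *b, size_t nb, uint16_t *out)
{
    for (size_t k = 0; k < na + nb; k++)
        out[k] = 0;
    for (size_t i = 0; i < na; i++) {
        uint32_t carry = 0;
        for (size_t j = 0; j < nb; j++) {
            /* at most (2^16-1)^2 + 2*(2^16-1) = 2^32 - 1, so no overflow */
            uint32_t t = (uint32_t)a[i] * b[j] + out[i + j] + carry;
            out[i + j] = (uint16_t)t;
            carry = t >> 16;
        }
        out[i + nb] = (uint16_t)carry;  /* this digit is still zero at this point */
    }
}
```

With a = {0xb929, 0x56ec} squared, out becomes {0x4891, 0x7584, 0xe74d, 0x1d83}, i.e. 0x1d83e74d75844891; passing four digits per factor gives the 128-bit product.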