How to convert 8-bit YUV420 image to RGB with Neon?

220 Views Asked by At

I'm new to Neon. I want to write code that converts YUV420 to RGB with Neon. The pixels have 8-bit depth, and I need to convert them to int32_t or float (still clamped to the 0-255 range). However, I haven't found any Neon instruction that does that. It also seems that there is no instruction that multiplies operands of different types, such as int16 and float. Do I have to do the data type conversion in C code to satisfy Neon's requirements?

Below is my algorithm for converting YUV to RGB. The parameters are in the range [0, 255]; I do the type conversion outside, in C code:

/**
 * Saturates a signed 32-bit value to the 8-bit range [0, 255].
 *
 * The previous version used a sign-mask trick whose upper bound depended on
 * `>>` of a negative int being an arithmetic shift, which is
 * implementation-defined before C++20. Plain comparisons give the same
 * results portably.
 *
 * @param n Value to clamp (any int32_t, including INT32_MIN / INT32_MAX).
 * @return 0 if n < 0, 255 if n > 255, otherwise n.
 */
inline std::uint8_t clamp(std::int32_t n) {
    if (n < 0) {
        return 0;
    }
    if (n > 255) {
        return 255;
    }
    return static_cast<std::uint8_t>(n);
}

/**
 * Converts one YUV sample to RGB using floating-point coefficients.
 *
 * Y is offset by 16 and U/V by 128 before the matrix is applied; the
 * coefficients appear to be the usual BT.601 limited-range ("studio swing")
 * ones (1.164 ~= 255 / 219) — confirm against your source format.
 * Each channel is computed in double, truncated toward zero by the int
 * conversion, then saturated to [0, 255] by clamp().
 *
 * @param y, u, v  Input components, expected in [0, 255].
 * @param r, g, b  Outputs, each in [0, 255].
 */
inline void yuv2rgb_i32(std::int32_t y, std::int32_t u, std::int32_t v, std::int32_t &r, std::int32_t &g, std::int32_t &b) {
    const std::int32_t luma = y - 16;
    const std::int32_t cb = u - 128;
    const std::int32_t cr = v - 128;

    // Truncate the double result, then saturate to the 8-bit range.
    const auto toChannel = [](double value) {
        return clamp(static_cast<std::int32_t>(value));
    };

    r = toChannel(1.164 * luma + 1.596 * cr);
    g = toChannel(1.164 * luma - 0.392 * cb - 0.813 * cr);
    b = toChannel(1.164 * luma + 2.017 * cb);
}
2

There are 2 best solutions below

3
Martin On

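Here is an example of how to convert planar YUV (BT.601 full swing, as used in JPEG, for example) to RGBA. Note that full swing uses Y directly instead of Y - 16, so the coefficients differ from the studio-swing formula in the question: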

/*------------------------------------------------------------------------------
* Global declarations (for use with linkage to C/C++ world)
*
* The symbol is global but has "internal" ELF visibility (similar to hidden):
* other objects linked into the same module can call it, but it is not
* exported from a shared library. The name is not C++-mangled, so C++
* callers must declare it with C linkage (extern "C").
* ----------------------------------------------------------------------------*/

.global neonYUV420ToRGBAFullSwing
.internal neonYUV420ToRGBAFullSwing
.type neonYUV420ToRGBAFullSwing, %function

/*------------------------------------------------------------------------------
* Definitions (internal)
*
* Offsets of the local variables the function keeps at [sp, #offset].
* ".struct" emits no data; it only advances a location counter, so each
* label evaluates to its byte offset:
*   yr_stackptr       0   (8 bytes, not referenced by the code in this file)
*   yr_width          8   (4 bytes)
*   yr_height        12   (4 bytes, decremented by 2 per processed row pair)
*   yr_outputstride  16   (4 bytes)
*   yr_lumastride    20   (4 bytes)
*   yr_chromastride  24   (4 bytes)
*   yr_transbuf      28   (16 bytes, not referenced by the code in this file)
* ----------------------------------------------------------------------------*/


.struct 0
yr_stackptr:
.struct yr_stackptr+8
yr_width:
.struct yr_width+4
yr_height:
.struct yr_height+4
yr_outputstride:
.struct yr_outputstride+4
yr_lumastride:
.struct yr_lumastride+4
yr_chromastride:
.struct yr_chromastride+4
yr_transbuf:
.struct yr_transbuf+16


/*------------------------------------------------------------------------------
* Text segment
*
* ".struct" switched the assembler into the absolute section; ".text" is
* required here to switch back before any code is emitted.
* ----------------------------------------------------------------------------*/

.text


/**
* neonYUV420ToRGBAFullSwing:
* 
* \param X0 Pointer to input Y data
* \param X1 Pointer to input U data
* \param X2 Pointer to input V data
* \param X3 Pointer to output RGBA data
* \param W4 width of block to process (in output RGBA quadruplets)
* \param W5 height of block to process (in pixels)
* \param W6 row-stride (in bytes) of output data block (rgba stride)
* \param W7 row-stride (in bytes) of input Y block (luma stride)
* \param stack row-stride (in bytes) of input U/V blocks (chroma stride),
*        i.e. the 9th argument, read from [sp] on entry
*
*
* This function converts YUV420 (planar) in BT.601 (full-swing) format to
* 32-Bit RGBA (8888) with alpha channel set to opaque.
*
* Requirements: width and height must both be even. Rows are processed in
* pairs and the scalar tail emits two pixels per iteration, so an odd
* dimension reads/writes one pixel past the block. A height <= 0 returns
* without touching memory.
*
* All size/stride arguments are 32-bit ints. They are sign-extended before
* being used in 64-bit address arithmetic, because AAPCS64 leaves the upper
* 32 bits of an X register carrying a 32-bit argument unspecified.
*
* Fixed-point arithmetic is used here and RGB is computed as follows:
*
* R = [ ( 128*Y +               179*(V-128) ) >> 7 ]
* G = [ ( 128*Y -  44*(U-128) -  91*(V-128) ) >> 7 ]
* B = [ ( 128*Y + 227*(U-128)               ) >> 7 ]
*
* where [ ] denotes saturation to 0..255.
*
* A is set to 255 statically.
*
*-----------------------------------------------------------------------------*/

        .align 2
neonYUV420ToRGBAFullSwing:
        /*----------------------------------------------------------------------
        * First establish proper stack-frame for facilitated gdb debugging...
        *
        * Save area below the entry SP: frame record (x29,x30) at -16,
        * x20/x28 at -32, x19 at -48, v8-v11 at -112, v12-v15 at -176.
        * SP is lowered by 432 bytes; the lowest 256 bytes hold the yr_*
        * locals.
        * --------------------------------------------------------------------*/
        mov x17,sp
        sub sp,sp,#(3*16+8*16+256)
        ldr w9,[x17]                                        // W9: chroma stride (9th argument, on the stack)
        stp x29,x30,[x17,#-16]!                             // Push frame pointer and return address (link register)
        mov x29,x17                                         // X29: point at the new frame record so debuggers can unwind
        stp x20,x28,[x17,#-16]!
        str x19,[x17,#-16]!                                 // Only X19 is used; X18 (platform register) is left untouched
        sub x17,x17,#4*16
        st4 {v8.16b,v9.16b,v10.16b,v11.16b},[x17]           // v8-v15 are callee-saved; saved in full
        sub x17,x17,#4*16
        st4 {v12.16b,v13.16b,v14.16b,v15.16b},[x17]
        str w4,[sp,#yr_width]
        str w5,[sp,#yr_height]
        str w6,[sp,#yr_outputstride]
        str w7,[sp,#yr_lumastride]
        str w9,[sp,#yr_chromastride]
        cmp w5,#0                                           // Empty block: nothing to convert
        ble 6f
        adr x9,yuvrgb_full_multiplier1
        ld1 {v30.4h},[x9]                                   // V30: 128, 179, -44, -91
        adr x9,yuvrgb_full_multiplier2
        ld1 {v31.4h},[x9]                                   // V31: 227, 0, 0, 0
        movi v9.8h,#128                                     // V9:  chroma offset
        dup v10.8h,v30.4h[0]                                // V10: 128 (luma scale)
        dup v11.8h,v30.4h[1]                                // V11: 179 (V -> R)
        dup v12.8h,v30.4h[2]                                // V12: -44 (U -> G)
        dup v13.8h,v30.4h[3]                                // V13: -91 (V -> G)
        dup v14.8h,v31.4h[0]                                // V14: 227 (U -> B)
        sxtw x7,w7                                          // X7:  luma stride, sign-extended
        add x4,x3,w6,sxtw                                   // X4:  RGBA data (next row); W6 sign-extended, upper half of X6 is unspecified
        add x5,x0,x7                                        // X5:  luma data (next row)
0:
        ldr w6,[sp,#yr_width]
        /*----------------------------------------------------------------------
        * Horizontal conversion loop (8 pixel per iteration, vectorized)
        *
        * X0: input Y (current row)       X1: input U (current row)
        * X2: input V (current row)       X3: output RGBA (current row)
        * X4: output RGBA (next row)      X5: input Y (next row)
        * W6: loop counter
        *
        * Each iteration converts 8 pixels of the current row and the 8 pixels
        * below them; both share the same 4 U and 4 V samples.
        *---------------------------------------------------------------------*/
        stp x0,x1,[sp,#-16]!                                // Save row-start pointers (restored at 5:)
        stp x2,x3,[sp,#-16]!
        stp x4,x5,[sp,#-16]!
        cmp w6,#8
        blt 3f
1:
        ld1 {v0.8b},[x0],#8                                 // V0: y7 y6 y5 y4 y3 y2 y1 y0 (8-bit)
        ld1r {v1.4s},[x1],#4                                // V1: u6 u4 u2 u0 u6 u4 u2 u0 (4 U bytes replicated per 32-bit lane)
        ld1r {v2.4s},[x2],#4                                // V2: v6 v4 v2 v0 v6 v4 v2 v0 (4 V bytes replicated per 32-bit lane)
        uxtl v3.8h,v0.8b                                    // V3: y7 y6 y5 y4 y3 y2 y1 y0 (16-bit)
        mov v4.8b,v1.8b
        mov v5.8b,v2.8b
        ld1 {v0.8b},[x5],#8                                 // V0: y7 y6 y5 y4 y3 y2 y1 y0 (8-bit) (next row)
        sub w6,w6,#8
        zip1 v1.8b,v1.8b,v4.8b                              // V1: u6 u6 u4 u4 u2 u2 u0 u0
        zip2 v2.8b,v2.8b,v5.8b                              // V2: v6 v6 v4 v4 v2 v2 v0 v0 (upper half == lower half after ld1r)
        uxtl v4.8h,v0.8b                                    // V4: y7 y6 y5 y4 y3 y2 y1 y0 (next row, 16-bit)
        uxtl v5.8h,v1.8b                                    // V5: u6 u6 u4 u4 u2 u2 u0 u0 (16-bit)
        uxtl v6.8h,v2.8b                                    // V6: v6 v6 v4 v4 v2 v2 v0 v0 (16-bit)
        sub v5.8h,v5.8h,v9.8h                               // V5: U-128
        sub v6.8h,v6.8h,v9.8h                               // V6: V-128
        mul v3.8h,v3.8h,v10.8h                              // V3: 128*Y
        mul v4.8h,v4.8h,v10.8h                              // V4: 128*Y (next row)
        mul v0.8h,v6.8h,v11.8h                              // V0: 179*(V-128)
        mul v7.8h,v5.8h,v14.8h                              // V7: 227*(U-128)
        mul v1.8h,v5.8h,v12.8h                              // V1: -44*(U-128)
        sqadd v0.8h,v0.8h,v3.8h                             // V0: R << 7
        mla v1.8h,v6.8h,v13.8h                              // V1: -44*(U-128)-91*(V-128)
        sqadd v7.8h,v7.8h,v3.8h                             // V7: B << 7
        sqadd v1.8h,v1.8h,v3.8h                             // V1: G << 7
        sqshrun v0.8b,v0.8h,#7                              // V0: R (8-bit, saturated to 0..255)
        sqshrun v1.8b,v1.8h,#7                              // V1: G (8-bit, saturated to 0..255)
        cmp w6,#8                                           // Flags for bge below; the NEON ops in between leave NZCV intact
        sqshrun v2.8b,v7.8h,#7                              // V2: B (8-bit, saturated to 0..255)
        movi v3.8b,#0xff                                    // V3: alpha
        mul v7.8h,v6.8h,v11.8h                              // V7: 179*(V-128)
        mul v6.8h,v6.8h,v13.8h                              // V6: -91*(V-128)
        st4 {v0.8b,v1.8b,v2.8b,v3.8b},[x3],#32              // Interleave R,G,B,A for 8 pixels
        sqadd v3.8h,v7.8h,v4.8h                             // V3: R << 7 (next row)
        mla v6.8h,v5.8h,v12.8h                              // V6: -44*(U-128) - 91*(V-128)
        mul v7.8h,v5.8h,v14.8h                              // V7: 227*(U-128)
        sqadd v6.8h,v6.8h,v4.8h                             // V6: G << 7 (next row)
        sqadd v4.8h,v4.8h,v7.8h                             // V4: B << 7 (next row)
        sqshrun v0.8b,v3.8h,#7                              // V0: R (8-bit, next row)
        movi v3.8b,#0xff                                    // V3: alpha
        sqshrun v1.8b,v6.8h,#7                              // V1: G (8-bit, next row)
        sqshrun v2.8b,v4.8h,#7                              // V2: B (8-bit, next row)
        st4 {v0.8b,v1.8b,v2.8b,v3.8b},[x4],#32
        bge 1b
3:
        mov w19,#255                                        // W19: alpha, also the upper clamp bound
        cmp w6,#0
        ble 5f
        /*----------------------------------------------------------------------
        * Horizontal conversion loop (scalar, 2 pixels x 2 rows per iteration)
        *
        * X0: input Y (current row)          X1: input U (current row)
        * X2: input V (current row)          X3: output RGBA (current row)
        * X4: output RGBA (next row)         X5: input Y (next row)
        * W6: loop counter (remaining pixels, decremented by 2)
        *
        * W8/W10/W11 hold the chroma terms shared by all four pixels.
        *---------------------------------------------------------------------*/
4:
        ldrb w7,[x0],#1                                     // W7: Y
        ldrb w8,[x1],#1                                     // W8: U
        ldrb w9,[x2],#1                                     // W9: V
        mov w12,#179
        sub w8,w8,#128
        sub w9,w9,#128
        lsl w7,w7,#7                                        // W7: 128*Y
        mov w11,#227
        mul w10,w9,w12                                      // W10: 179*(V-128) (re-usable)
        mov w12,#-44
        mul w11,w8,w11                                      // W11: 227*(U-128) (re-usable)
        mul w8,w8,w12                                       // W8:  -44*(U-128)
        mov w12,#-91
        madd w8,w9,w12,w8                                   // W8:  -44*(U-128) - 91*(V-128) (re-usable)
        add w9,w7,w10                                       // W9:  R << 7
        add w12,w7,w8                                       // W12: G << 7
        add w7,w7,w11                                       // W7:  B << 7
        asr w9,w9,#7                                        // W9:  R (unclipped)
        cmp w9,#0
        csel w9,wzr,w9,mi                                   // Clamp below at 0
        asr w12,w12,#7                                      // W12: G (unclipped)
        cmp w12,#0
        csel w12,wzr,w12,mi
        asr w7,w7,#7                                        // W7:  B (unclipped)
        cmp w7,#0
        csel w7,wzr,w7,mi
        cmp w9,w19
        csel w9,w19,w9,gt                                   // Clamp above at 255
        cmp w12,w19
        csel w12,w19,w12,gt
        cmp w7,w19
        csel w7,w19,w7,gt
        strb w9,[x3],#1
        strb w12,[x3],#1
        strb w7,[x3],#1
        strb w19,[x3],#1
        /* next pixel */
        ldrb w7,[x0],#1                                     // W7:  Y
        lsl w7,w7,#7
        add w9,w7,w10                                       // W9:  R << 7
        add w12,w7,w8                                       // W12: G << 7
        add w7,w7,w11                                       // W7:  B << 7
        asr w9,w9,#7                                        // W9:  R (unclipped)
        cmp w9,#0
        csel w9,wzr,w9,mi
        asr w12,w12,#7                                      // W12: G (unclipped)
        cmp w12,#0
        csel w12,wzr,w12,mi
        asr w7,w7,#7                                        // W7:  B (unclipped)
        cmp w7,#0
        csel w7,wzr,w7,mi
        cmp w9,w19
        csel w9,w19,w9,gt
        cmp w12,w19
        csel w12,w19,w12,gt
        cmp w7,w19
        csel w7,w19,w7,gt
        strb w9,[x3],#1
        strb w12,[x3],#1
        strb w7,[x3],#1
        strb w19,[x3],#1
        /* next row */
        ldrb w7,[x5],#1                                     // W7:  Y
        lsl w7,w7,#7
        add w9,w7,w10                                       // W9:  R << 7
        add w12,w7,w8                                       // W12: G << 7
        add w7,w7,w11                                       // W7:  B << 7
        asr w9,w9,#7                                        // W9:  R (unclipped)
        cmp w9,#0
        csel w9,wzr,w9,mi
        asr w12,w12,#7                                      // W12: G (unclipped)
        cmp w12,#0
        csel w12,wzr,w12,mi
        asr w7,w7,#7                                        // W7:  B (unclipped)
        cmp w7,#0
        csel w7,wzr,w7,mi
        cmp w9,w19
        csel w9,w19,w9,gt
        cmp w12,w19
        csel w12,w19,w12,gt
        cmp w7,w19
        csel w7,w19,w7,gt
        strb w9,[x4],#1
        strb w12,[x4],#1
        strb w7,[x4],#1
        strb w19,[x4],#1
        /* next pixel */
        ldrb w7,[x5],#1                                     // W7:  Y
        lsl w7,w7,#7
        add w9,w7,w10                                       // W9:  R << 7
        add w12,w7,w8                                       // W12: G << 7
        add w7,w7,w11                                       // W7:  B << 7
        asr w9,w9,#7                                        // W9:  R (unclipped)
        cmp w9,#0
        csel w9,wzr,w9,mi
        asr w12,w12,#7                                      // W12: G (unclipped)
        cmp w12,#0
        csel w12,wzr,w12,mi
        asr w7,w7,#7                                        // W7:  B (unclipped)
        cmp w7,#0
        csel w7,wzr,w7,mi
        cmp w9,w19
        csel w9,w19,w9,gt
        cmp w12,w19
        csel w12,w19,w12,gt
        cmp w7,w19
        csel w7,w19,w7,gt
        subs w6,w6,#2                                       // Flags for bgt; strb does not modify them
        strb w9,[x4],#1
        strb w12,[x4],#1
        strb w7,[x4],#1
        strb w19,[x4],#1
        bgt 4b
5:
        ldp x4,x5,[sp],#16                                  // Restore row-start pointers
        ldp x2,x3,[sp],#16
        ldp x0,x1,[sp],#16
        ldr w7,[sp,#yr_height]
        ldrsw x6,[sp,#yr_outputstride]
        subs w7,w7,#2                                       // Two rows done per pass
        ldrsw x8,[sp,#yr_lumastride]
        add x3,x3,x6,lsl #1
        ldrsw x9,[sp,#yr_chromastride]
        add x4,x4,x6,lsl #1
        add x0,x0,x8,lsl #1
        add x5,x5,x8,lsl #1
        add x1,x1,x9                                        // One chroma row per two luma rows
        add x2,x2,x9
        str w7,[sp,#yr_height]
        bgt 0b
6:
        /*----------------------------------------------------------------------
        * Epilogue and out...
        *---------------------------------------------------------------------*/
        add x17,sp,#256
        ld4 {v12.16b,v13.16b,v14.16b,v15.16b},[x17],#64
        ld4 {v8.16b,v9.16b,v10.16b,v11.16b},[x17],#64
        ldr x19,[x17],#16
        ldp x20,x28,[x17],#16
        ldp x29,x30,[x17],#16
        add sp,sp,#(3*16+8*16+256)
        ret                                   // Restore frame-pointer and return




/*------------------------------------------------------------------------------
* Constants
*
* Coefficient tables read with ADR + LD1 {Vn.4h} by neonYUV420ToRGBAFullSwing.
* Each table holds four signed 16-bit values and starts on a 16-byte boundary.
*   multiplier1: luma scale 128, V->R 179, U->G -44, V->G -91
*   multiplier2: U->B 227 (remaining lanes unused)
* ----------------------------------------------------------------------------*/

        .p2align 4
yuvrgb_full_multiplier1:
        .hword 128, 179, -44, -91

        .p2align 4
yuvrgb_full_multiplier2:
        .hword 227, 0, 0, 0

This is a stand-alone file that can be assembled with the GNU Assembler for AArch64. Once it is assembled into an object file, you can link it with the rest of your code. From C, declare the function as follows (when calling it from C++, the declaration must additionally have C linkage, i.e. `extern "C"`):

/* Prototype for the AArch64 assembly routine neonYUV420ToRGBAFullSwing.
   The assembly symbol is not name-mangled, so C++ translation units must see
   it with C linkage; the guard keeps the declaration valid in plain C too.
   All int arguments are 32-bit (the routine reads W registers); chromaStride
   is the 9th argument and therefore passed on the stack. */
#ifdef __cplusplus
extern "C"
#endif
void neonYUV420ToRGBAFullSwing(const uint8_t *yInput, const uint8_t *uInput, const uint8_t *vInput, uint8_t *rgbaOutput, int width, int height, int rgbaStride, int lumaStride, int chromaStride);

The code is probably not optimal, but it is not slow either. It expects both image dimensions to be even, so pad the image if that is not the case.

9
zuguorui On

Edit (2024-01-27):

I completed the YUV420-to-ARGB_8888 code. The conversion formula is taken from @Martin's answer. The code runs on a Xiaomi 10 Ultra. ImageProxy is a C++ proxy class for android.media.Image, which is produced by ImageReader.

/**
 * Loads 16 consecutive 8-bit luma samples, splits them by column parity and
 * widens each half to signed 16-bit lanes.
 *
 * from (memory order):
 *   Y1 Y2 Y3 Y4 ... Y15 Y16
 * to:
 *   val[0] = Y1 Y3 Y5 ... Y15   (even columns, 0-based)
 *   val[1] = Y2 Y4 Y6 ... Y16   (odd columns)
 *
 * @param buffer First of the 16 samples; only read (16 bytes).
 * @return De-interleaved samples, each in [0, 255].
 */
static inline int16x8x2_t neon_load_y(const uint8_t *buffer) {
    // vld2 de-interleaves even/odd bytes into two 8-lane vectors.
    const uint8x8x2_t u8_2 = vld2_u8(buffer);
    int16x8x2_t s16_2;
    // Y[even]; zero-extension keeps values in [0, 255], so the signed view is exact.
    s16_2.val[0] = vreinterpretq_s16_u16(vmovl_u8(u8_2.val[0]));
    // Y[odd]
    s16_2.val[1] = vreinterpretq_s16_u16(vmovl_u8(u8_2.val[1]));
    return s16_2;
}

/**
 * Loads 8 chroma (U or V) samples from a plane with pixel stride 2 and
 * widens them to signed 16-bit lanes.
 *
 * from (memory order):
 *   U1 X U2 X U3 X ... U8 X
 * to:
 *   U1 U2 U3 ... U8
 * where X is the unused byte between samples (pixel stride 2 is assumed).
 *
 * 16 bytes are read, including the trailing X byte.
 * NOTE(review): at the very end of a plane that trailing byte may not be part
 * of the plane buffer — callers should confirm the read stays in bounds.
 *
 * @param buffer Address of the first sample; only read.
 * @return The 8 samples, each in [0, 255].
 */
static inline int16x8_t neon_load_uv(const uint8_t *buffer) {
    // vld2 separates even bytes (samples, val[0]) from odd bytes (val[1], discarded).
    const uint8x8x2_t u8_2 = vld2_u8(buffer);
    return vreinterpretq_s16_u16(vmovl_u8(u8_2.val[0]));
}

/**
 * Converts an android.media.Image in YUV_420_888 format (wrapped by
 * ImageProxy) into a newly created ARGB_8888 Bitmap, using NEON for the
 * color math.
 *
 * Fixed-point formula (7 fractional bits, coefficients from Martin's answer):
 *   R = (128*Y + 179*(V-128)) >> 7
 *   G = (128*Y -  44*(U-128) - 91*(V-128)) >> 7
 *   B = (128*Y + 227*(U-128)) >> 7
 * Each result is saturated to [0, 255]; alpha is always 255.
 *
 * Supported input: Y pixel stride 1, U/V pixel stride 2, width a multiple of
 * 16 and height a multiple of 2. This is checked at runtime (not only by
 * assert), so release builds cannot write past the bitmap.
 *
 * @param env   JNI environment of the calling thread.
 * @param image Source image; its planes are only read.
 * @return The Bitmap created by Bitmap.createBitmap (may be null if creation
 *         failed). If the input layout is unsupported, or the bitmap cannot
 *         be queried or locked, the failure is logged and the bitmap is
 *         returned unconverted.
 */
jobject convert_YUV_420_888_neon(JNIEnv *env, ImageProxy &image) {
    // Chroma bias. A plain local avoids the thread-safe init guard that a
    // function-local static pays on every call; vdupq is a single instruction.
    const int16x8_t chromaOffset = vdupq_n_s16(128);
    const int bitmapWidth = image.getWidth();
    const int bitmapHeight = image.getHeight();

    if (bitmapClass == nullptr) {
        LOGE(TAG, "JNI object not init, init");
        initJNI(env);
    }

    jobject bitmap = env->CallStaticObjectMethod(bitmapClass, bitmapCreateMethod, bitmapWidth, bitmapHeight, argb8888Obj);
    if (bitmap == nullptr) {
        // Creation failed (presumably with a Java exception pending); nothing to fill.
        return bitmap;
    }

    uint8_t *yBuffer, *uBuffer, *vBuffer;
    int yBufferLen, uBufferLen, vBufferLen;
    int yRowStride, uRowStride, vRowStride;
    int yPixelStride, uPixelStride, vPixelStride;

    image.getPlane(0, &yBuffer, yBufferLen, yRowStride, yPixelStride);
    image.getPlane(1, &uBuffer, uBufferLen, uRowStride, uPixelStride);
    image.getPlane(2, &vBuffer, vBufferLen, vRowStride, vPixelStride);

    // The NEON loads below hard-code this plane layout and block size; an
    // assert alone would let NDEBUG builds read/write out of bounds.
    if (yPixelStride != 1 || uPixelStride != 2 || vPixelStride != 2
            || bitmapWidth % 16 != 0 || bitmapHeight % 2 != 0) {
        LOGE(TAG, "unsupported YUV layout or size: %dx%d", bitmapWidth, bitmapHeight);
        return bitmap;
    }

    // Android bitmaps expose their row pitch; don't assume it equals the width.
    AndroidBitmapInfo bitmapInfo;
    if (AndroidBitmap_getInfo(env, bitmap, &bitmapInfo) != ANDROID_BITMAP_RESULT_SUCCESS) {
        LOGE(TAG, "AndroidBitmap_getInfo failed");
        return bitmap;
    }
    const int bitmapStride = static_cast<int>(bitmapInfo.stride / sizeof(uint32_t));  // in pixels

    uint32_t *bitmapBuffer = nullptr;
    if (AndroidBitmap_lockPixels(env, bitmap, reinterpret_cast<void **>(&bitmapBuffer)) != ANDROID_BITMAP_RESULT_SUCCESS
            || bitmapBuffer == nullptr) {
        LOGE(TAG, "AndroidBitmap_lockPixels failed");
        return bitmap;
    }

    uint8_t rBuffer[8], gBuffer[8], bBuffer[8];

    chrono::time_point startTime = chrono::system_clock::now();
    for (int row = 0; row < bitmapHeight; row += 2) {
        // One chroma row serves two luma rows (vertical 2:1 subsampling).
        uint8_t *uRow = uBuffer + row / 2 * uRowStride;
        uint8_t *vRow = vBuffer + row / 2 * vRowStride;
        for (int col = 0; col < bitmapWidth; col += 16) {
            // With pixel stride 2, byte offset `col` is chroma sample col / 2,
            // shared by luma columns col and col + 1.
            // NOTE(review): each load reads 16 bytes; on the last chroma row
            // with col == width - 16 the final (unused) byte may lie past the
            // end of the plane buffer (u/vBufferLen are not checked) — confirm.
            const int16x8_t u = vsubq_s16(neon_load_uv(uRow + col), chromaOffset);
            const int16x8_t v = vsubq_s16(neon_load_uv(vRow + col), chromaOffset);

            // u, v are in [-128, 127], so none of these int16 products overflows.
            const int16x8_t rChroma = vmulq_n_s16(v, 179);                 // 179 * (v - 128)
            const int16x8_t gChroma = vaddq_s16(vmulq_n_s16(u, 44),        // 44 * (u - 128)
                                                vmulq_n_s16(v, 91));       // + 91 * (v - 128)
            const int16x8_t bChroma = vmulq_n_s16(u, 227);                 // 227 * (u - 128)

            for (int lineOddEven = 0; lineOddEven < 2; lineOddEven++) {
                const int outRow = row + lineOddEven;
                const int16x8x2_t y_2 = neon_load_y(yBuffer + outRow * yRowStride + col);
                uint32_t *out = bitmapBuffer + outRow * bitmapStride + col;
                for (int colOddEven = 0; colOddEven < 2; colOddEven++) {
                    // 128 * Y <= 32640 still fits in int16.
                    const int16x8_t y = vmulq_n_s16(y_2.val[colOddEven], 128);

                    // Saturating add/sub keeps the fixed-point sums in range;
                    // the arithmetic shift drops the 7 fractional bits.
                    const int16x8_t r1 = vshrq_n_s16(vqaddq_s16(y, rChroma), 7);
                    const int16x8_t g1 = vshrq_n_s16(vqsubq_s16(y, gChroma), 7);
                    const int16x8_t b1 = vshrq_n_s16(vqaddq_s16(y, bChroma), 7);

                    // Saturating narrow to uint8 clamps to [0, 255].
                    vst1_u8(rBuffer, vqmovun_s16(r1));
                    vst1_u8(gBuffer, vqmovun_s16(g1));
                    vst1_u8(bBuffer, vqmovun_s16(b1));

                    // Lane i holds column 2*i + colOddEven. R goes in the low
                    // byte and alpha in the high byte (same packing as before);
                    // unsigned arithmetic avoids shifting into the int sign bit.
                    for (int i = 0; i < 8; i++) {
                        out[2 * i + colOddEven] = 0xFF000000u
                                | (static_cast<uint32_t>(bBuffer[i]) << 16)
                                | (static_cast<uint32_t>(gBuffer[i]) << 8)
                                | static_cast<uint32_t>(rBuffer[i]);
                    }
                }
            }
        }
    }
    chrono::time_point endTime = chrono::system_clock::now();
    chrono::duration oneImageTime = endTime - startTime;
    long ms = chrono::duration_cast<chrono::milliseconds>(oneImageTime).count();
    debugIndex++;
    timeMS += ms;
    if (debugIndex >= DEBUG_LOOP) {
        long avg = timeMS / debugIndex;
        LOGD(TAG, "convert neon raw, %d images avg cost %d ms, image size = [%d, %d]", debugIndex, (int)avg, bitmapWidth, bitmapHeight);
        debugIndex = 0;
        timeMS = 0;
    }
    AndroidBitmap_unlockPixels(env, bitmap);
    return bitmap;
}

Tested with an image size of 3840x2160.

When compiled with -O0, it takes 160+ ms, while the plain C version takes 210+ ms.

When compiled with -O1, it takes 30+ ms, while the plain C version takes 60+ ms.