C language: Efficient 4x4 matrix multiplication (C vs assembly)

Warning: this page is a translation of a popular StackOverflow question and its answers, provided under the CC BY-SA 4.0 license. You are free to use/share it, but you must do so under the same license and attribute it to the original authors (not me): StackOverflow. Original question: http://stackoverflow.com/questions/18499971/

Date: 2020-09-02 07:18:56 | Source: igfitidea

Efficient 4x4 matrix multiplication (C vs assembly)

Tags: c, optimization, assembly, sse, matrix-multiplication

Asked by Krzysztof Abramowicz

I'm looking for a faster and trickier way to multiply two 4x4 matrices in C. My current research is focused on x86-64 assembly with SIMD extensions. So far, I've created a function which is about 6x faster than a naive C implementation, which has exceeded my expectations for the performance improvement. Unfortunately, this holds true only when no optimization flags are used for compilation (GCC 4.7). With -O2, the C version becomes faster and my effort becomes meaningless.


I know that modern compilers make use of complex optimization techniques to produce almost perfect code, usually faster than an ingenious piece of hand-crafted assembly. But in a minority of performance-critical cases, a human may try to fight the compiler for clock cycles. Especially when some mathematics backed by a modern ISA can be exploited (as it is in my case).


My function looks as follows (AT&T syntax, GNU Assembler):


    .text
    .globl matrixMultiplyASM
    .type matrixMultiplyASM, @function
matrixMultiplyASM:
    movaps   (%rdi), %xmm0    # fetch the first matrix (use four registers)
    movaps 16(%rdi), %xmm1
    movaps 32(%rdi), %xmm2
    movaps 48(%rdi), %xmm3
    xorq %rcx, %rcx           # reset (forward) loop iterator
.ROW:
    movss (%rsi), %xmm4       # Compute four values (one row) in parallel:
    shufps $0x0, %xmm4, %xmm4 # 4x 4FP mul's, 3x 4FP add's, 6x mov's per row,
    mulps %xmm0, %xmm4        # expressed in four sequences of 5 instructions,
    movaps %xmm4, %xmm5       # executed 4 times for 1 matrix multiplication.
    addq $0x4, %rsi

    movss (%rsi), %xmm4       # movss + shufps comprise _mm_set1_ps intrinsic
    shufps $0x0, %xmm4, %xmm4 #
    mulps %xmm1, %xmm4
    addps %xmm4, %xmm5
    addq $0x4, %rsi           # manual pointer arithmetic simplifies addressing

    movss (%rsi), %xmm4
    shufps $0x0, %xmm4, %xmm4
    mulps %xmm2, %xmm4        # actual computation happens here
    addps %xmm4, %xmm5        #
    addq $0x4, %rsi

    movss (%rsi), %xmm4       # one mulps operand fetched per sequence
    shufps $0x0, %xmm4, %xmm4 #  |
    mulps %xmm3, %xmm4        # the other is already waiting in %xmm[0-3]
    addps %xmm4, %xmm5
    addq $0x4, %rsi           # 5 preceding comments stride among the 4 blocks

    movaps %xmm5, (%rdx,%rcx) # store the resulting row, actually, a column
    addq $0x10, %rcx          # (matrices are stored in column-major order)
    cmpq $0x40, %rcx
    jne .ROW
    ret
    .size matrixMultiplyASM, .-matrixMultiplyASM

It calculates a whole column of the resultant matrix per iteration by processing four floats packed in 128-bit SSE registers. Full vectorisation is possible with a bit of math (operation reordering and aggregation) and mulps/addps instructions for parallel multiplication/addition of packets of four floats. The code reuses the registers meant for passing parameters (%rdi, %rsi, %rdx: GNU/Linux ABI), benefits from (inner) loop unrolling, and holds one matrix entirely in XMM registers to reduce memory reads. As you can see, I have researched the topic and taken my time to implement it as well as I can.

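To put the routine to work, it has to be assembled and linked against a C caller. Below is a minimal calling sketch of my own; the prototype and the column-major layout are inferred from the listing (%rdi = first matrix, %rsi = second matrix, %rdx = result, each being 16 packed floats aligned to 16 bytes), not taken from the original post. Built with something like "gcc main.c matmul.s" (file names are placeholders), it makes it easy to cross-check the assembly against the C reference that follows.

#include <stdio.h>

/* Prototype inferred from the register usage above (an assumption, not from the post). */
extern void matrixMultiplyASM(const float *a, const float *b, float *r);

int main(void) {
    /* 16-byte alignment is required by the movaps loads/stores. */
    __attribute__((aligned(16))) float a[16], b[16], r[16];
    for (int i = 0; i < 16; ++i) {
        a[i] = (float)(i + 1);                  /* arbitrary test data */
        b[i] = (i % 5 == 0) ? 2.0f : 0.0f;      /* b = 2 * identity    */
    }
    matrixMultiplyASM(a, b, r);                 /* with b = 2*I, r should be 2*a */
    for (int i = 0; i < 16; ++i)
        printf("%6.1f%c", r[i], (i % 4 == 3) ? '\n' : ' ');
    return 0;
}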

The naive C calculation that outperforms my code looks like this:


void matrixMultiplyNormal(mat4_t *mat_a, mat4_t *mat_b, mat4_t *mat_r) {
    for (unsigned int i = 0; i < 16; i += 4)
        for (unsigned int j = 0; j < 4; ++j)
            mat_r->m[i + j] = (mat_b->m[i + 0] * mat_a->m[j +  0])
                            + (mat_b->m[i + 1] * mat_a->m[j +  4])
                            + (mat_b->m[i + 2] * mat_a->m[j +  8])
                            + (mat_b->m[i + 3] * mat_a->m[j + 12]);
}
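As a side note (this is my own sketch, not part of the question), the same column-major formula can also be written with SSE intrinsics instead of hand-written assembly; each _mm_set1_ps below corresponds to the movss + shufps pair in the listing, while register allocation and scheduling are left to the compiler. The float pointers stand for mat_a->m, mat_b->m and mat_r->m and are assumed to be 16-byte aligned.

#include <xmmintrin.h>

void matrixMultiplyIntrin(const float *a, const float *b, float *r) {
    __m128 a0 = _mm_load_ps(a +  0);   /* the four columns of the first matrix */
    __m128 a1 = _mm_load_ps(a +  4);
    __m128 a2 = _mm_load_ps(a +  8);
    __m128 a3 = _mm_load_ps(a + 12);
    for (int i = 0; i < 4; ++i) {      /* one result column per iteration */
        __m128 c =        _mm_mul_ps(a0, _mm_set1_ps(b[4*i + 0]));
        c = _mm_add_ps(c, _mm_mul_ps(a1, _mm_set1_ps(b[4*i + 1])));
        c = _mm_add_ps(c, _mm_mul_ps(a2, _mm_set1_ps(b[4*i + 2])));
        c = _mm_add_ps(c, _mm_mul_ps(a3, _mm_set1_ps(b[4*i + 3])));
        _mm_store_ps(r + 4*i, c);
    }
}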

I have investigated the optimised assembly output of the naive C code above which, while keeping floats in XMM registers, does not involve any parallel operations, only scalar calculations, pointer arithmetic and conditional jumps. The compiler's code seems less deliberate, yet it is still slightly more effective than my vectorised version, which I expected to be about 4x faster. I'm sure the general idea is correct; programmers do similar things with rewarding results. But what is wrong here? Are there any register allocation or instruction scheduling issues I am not aware of? Do you know of any x86-64 assembly tools or tricks that could support my battle against the machine?


Accepted answer by Krzysztof Abramowicz

There is a way to accelerate the code and outplay the compiler. It does not involve any sophisticated pipeline analysis or deep code micro-optimisation (which doesn't mean that it couldn't further benefit from these). The optimisation uses three simple tricks:


  1. The function is now 32-byte aligned (which significantly boosted performance),

  2. The main loop iterates in reverse, which reduces the comparison to a test against zero (based on EFLAGS),

  3. Instruction-level address arithmetic proved to be faster than the "external" pointer calculation (even though it requires twice as many additions in 3 out of 4 cases). It shortened the loop body by four instructions and reduced data dependencies within its execution path. See the related question.


Additionally, the code uses a relative jump syntax which suppresses the symbol redefinition error that occurs when GCC tries to inline it (after being placed within an asm statement and compiled with -O3).


    .text
    .align 32                           # 1. function entry alignment
    .globl matrixMultiplyASM            #    (for a faster call)
    .type matrixMultiplyASM, @function
matrixMultiplyASM:
    movaps   (%rdi), %xmm0
    movaps 16(%rdi), %xmm1
    movaps 32(%rdi), %xmm2
    movaps 48(%rdi), %xmm3
    movq $48, %rcx                      # 2. loop reversal
1:                                      #    (for simpler exit condition)
    movss (%rsi, %rcx), %xmm4           # 3. extended address operands
    shufps $0x0, %xmm4, %xmm4           #    (faster than pointer calculation)
    mulps %xmm0, %xmm4
    movaps %xmm4, %xmm5
    movss 4(%rsi, %rcx), %xmm4
    shufps $0x0, %xmm4, %xmm4
    mulps %xmm1, %xmm4
    addps %xmm4, %xmm5
    movss 8(%rsi, %rcx), %xmm4
    shufps $0x0, %xmm4, %xmm4
    mulps %xmm2, %xmm4
    addps %xmm4, %xmm5
    movss 12(%rsi, %rcx), %xmm4
    shufps $0x0, %xmm4, %xmm4
    mulps %xmm3, %xmm4
    addps %xmm4, %xmm5
    movaps %xmm5, (%rdx, %rcx)
    subq $16, %rcx                      # one 'sub' (vs 'add' & 'cmp')
    jge 1b                              # SF=OF, idiom: jump if positive
    ret
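The label trick can be shown in isolation: GAS numeric labels such as 1: may be defined many times, and 1b always refers to the nearest preceding definition, so GCC can duplicate the block while inlining without hitting a symbol redefinition error. A minimal sketch of my own (not the matrix code) follows.

static inline void spin3(void) {
    /* A throw-away counted loop; if this function is inlined in several places,
       every copy of the asm block gets its own local '1:' label. */
    __asm__ volatile (
        "movl $3, %%ecx\n\t"
        "1:\n\t"
        "subl $1, %%ecx\n\t"
        "jne 1b\n\t"            /* relative jump back to the nearest '1:' */
        : : : "ecx", "cc");
}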

This is the fastest x86-64 implementation I have seen so far. I will appreciate, vote up and accept any answer providing a faster piece of assembly for that purpose!


Answer by Z boson

4x4 matrix multiplication is 64 multiplications and 48 additions. Using SSE this can be reduced to 16 multiplications and 12 additions (and 16 broadcasts). The following code will do this for you. It only requires SSE (#include <xmmintrin.h>). The arrays A, B, and C need to be 16-byte aligned. Using horizontal instructions such as hadd (SSE3) and dpps (SSE4.1) will be less efficient (especially dpps). I don't know if loop unrolling will help.


void M4x4_SSE(float *A, float *B, float *C) {
    __m128 row1 = _mm_load_ps(&B[0]);
    __m128 row2 = _mm_load_ps(&B[4]);
    __m128 row3 = _mm_load_ps(&B[8]);
    __m128 row4 = _mm_load_ps(&B[12]);
    for(int i=0; i<4; i++) {
        __m128 brod1 = _mm_set1_ps(A[4*i + 0]);
        __m128 brod2 = _mm_set1_ps(A[4*i + 1]);
        __m128 brod3 = _mm_set1_ps(A[4*i + 2]);
        __m128 brod4 = _mm_set1_ps(A[4*i + 3]);
        __m128 row = _mm_add_ps(
                    _mm_add_ps(
                        _mm_mul_ps(brod1, row1),
                        _mm_mul_ps(brod2, row2)),
                    _mm_add_ps(
                        _mm_mul_ps(brod3, row3),
                        _mm_mul_ps(brod4, row4)));
        _mm_store_ps(&C[4*i], row);
    }
}
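A small usage sketch of my own (not from the answer): the matrices live in plain float[16] arrays, and the 16-byte alignment required by _mm_load_ps/_mm_store_ps is requested explicitly. With two identity matrices the result should be the identity again.

#include <stdio.h>

int main(void) {
    /* compiled together with M4x4_SSE above */
    __attribute__((aligned(16))) float A[16] = {1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1};
    __attribute__((aligned(16))) float B[16] = {1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1};
    __attribute__((aligned(16))) float C[16];

    M4x4_SSE(A, B, C);

    for (int i = 0; i < 16; ++i)
        printf("%5.1f%c", C[i], (i % 4 == 3) ? '\n' : ' ');
    return 0;
}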

Answer by Sparky

I wonder if transposing one of the matrices may be beneficial.


Consider how we multiply the following two matrices ...


A1 A2 A3 A4        W1 W2 W3 W4
B1 B2 B3 B4        X1 X2 X3 X4
C1 C2 C3 C4    *   Y1 Y2 Y3 Y4
D1 D2 D3 D4        Z1 Z2 Z3 Z4

This would result in ...


dot(A,col1) dot(A,col2) dot(A,col3) dot(A,col4)
dot(B,col1) dot(B,col2) dot(B,col3) dot(B,col4)
dot(C,col1) dot(C,col2) dot(C,col3) dot(C,col4)
dot(D,col1) dot(D,col2) dot(D,col3) dot(D,col4)

(where colj is the j-th column of the second matrix, e.g. col1 = (W1, X1, Y1, Z1))

Doing the dot product of a row and a column is a pain.


What if we transposed the second matrix before we multiplied?


A1 A2 A3 A4        W1 X1 Y1 Z1
B1 B2 B3 B4        W2 X2 Y2 Z2
C1 C2 C3 C4    *   W3 X3 Y3 Z3
D1 D2 D3 D4        W4 X4 Y4 Z4

Now, instead of doing the dot product of a row and a column, we are doing the dot product of two rows. This could lend itself to better use of the SIMD instructions.

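Here is a sketch of my own of what that could look like with SSE (only an illustration of the idea, not a claim that it is faster; the horizontal reduction at the end of each dot product is not free). Row-major 4x4 arrays of 16-byte-aligned floats are assumed, and the _MM_TRANSPOSE4_PS macro from <xmmintrin.h> does the transposition.

#include <xmmintrin.h>

/* horizontal dot product of two 4-float vectors */
static inline float dot4(__m128 x, __m128 y) {
    __m128 p = _mm_mul_ps(x, y);
    __m128 s = _mm_add_ps(p, _mm_movehl_ps(p, p));     /* p0+p2, p1+p3, ... */
    s = _mm_add_ss(s, _mm_shuffle_ps(s, s, 0x55));     /* + (p1+p3)         */
    return _mm_cvtss_f32(s);
}

void M4x4_transposed(const float *A, const float *B, float *C) {
    __m128 b0 = _mm_load_ps(&B[0]),  b1 = _mm_load_ps(&B[4]);
    __m128 b2 = _mm_load_ps(&B[8]),  b3 = _mm_load_ps(&B[12]);
    _MM_TRANSPOSE4_PS(b0, b1, b2, b3);                 /* columns of B become rows */
    for (int i = 0; i < 4; ++i) {
        __m128 a = _mm_load_ps(&A[4*i]);               /* row i of A */
        C[4*i + 0] = dot4(a, b0);
        C[4*i + 1] = dot4(a, b1);
        C[4*i + 2] = dot4(a, b2);
        C[4*i + 3] = dot4(a, b3);
    }
}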

Hope this helps.


Answer by Dick Bertrand

Sandy Bridge and above extend the instruction set to support 8-element vector arithmetic. Consider this implementation.


struct MATRIX {
    union {
        float  f[4][4];
        __m128 m[4];
        __m256 n[2];
    };
};
MATRIX myMultiply(MATRIX M1, MATRIX M2) {
    // Perform a 4x4 matrix multiply by a 4x4 matrix 
    // Be sure to run in 64 bit mode and set right flags
    // Properties, C/C++, Enable Enhanced Instruction, /arch:AVX 
    // Having MATRIX on a 32-byte boundary does help performance
    MATRIX mResult;
    __m256 a0, a1, b0, b1;
    __m256 c0, c1, c2, c3, c4, c5, c6, c7;
    __m256 t0, t1, u0, u1;

    t0 = M1.n[0];                                                   // t0 = a00, a01, a02, a03, a10, a11, a12, a13
    t1 = M1.n[1];                                                   // t1 = a20, a21, a22, a23, a30, a31, a32, a33
    u0 = M2.n[0];                                                   // u0 = b00, b01, b02, b03, b10, b11, b12, b13
    u1 = M2.n[1];                                                   // u1 = b20, b21, b22, b23, b30, b31, b32, b33

    a0 = _mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(0, 0, 0, 0));        // a0 = a00, a00, a00, a00, a10, a10, a10, a10
    a1 = _mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(0, 0, 0, 0));        // a1 = a20, a20, a20, a20, a30, a30, a30, a30
    b0 = _mm256_permute2f128_ps(u0, u0, 0x00);                      // b0 = b00, b01, b02, b03, b00, b01, b02, b03  
    c0 = _mm256_mul_ps(a0, b0);                                     // c0 = a00*b00  a00*b01  a00*b02  a00*b03  a10*b00  a10*b01  a10*b02  a10*b03
    c1 = _mm256_mul_ps(a1, b0);                                     // c1 = a20*b00  a20*b01  a20*b02  a20*b03  a30*b00  a30*b01  a30*b02  a30*b03

    a0 = _mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(1, 1, 1, 1));        // a0 = a01, a01, a01, a01, a11, a11, a11, a11
    a1 = _mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(1, 1, 1, 1));        // a1 = a21, a21, a21, a21, a31, a31, a31, a31
    b0 = _mm256_permute2f128_ps(u0, u0, 0x11);                      // b0 = b10, b11, b12, b13, b10, b11, b12, b13
    c2 = _mm256_mul_ps(a0, b0);                                     // c2 = a01*b10  a01*b11  a01*b12  a01*b13  a11*b10  a11*b11  a11*b12  a11*b13
    c3 = _mm256_mul_ps(a1, b0);                                     // c3 = a21*b10  a21*b11  a21*b12  a21*b13  a31*b10  a31*b11  a31*b12  a31*b13

    a0 = _mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(2, 2, 2, 2));        // a0 = a02, a02, a02, a02, a12, a12, a12, a12
    a1 = _mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(2, 2, 2, 2));        // a1 = a22, a22, a22, a22, a32, a32, a32, a32
    b1 = _mm256_permute2f128_ps(u1, u1, 0x00);                      // b1 = b20, b21, b22, b23, b20, b21, b22, b23
    c4 = _mm256_mul_ps(a0, b1);                                     // c4 = a02*b20  a02*b21  a02*b22  a02*b23  a12*b20  a12*b21  a12*b22  a12*b23
    c5 = _mm256_mul_ps(a1, b1);                                     // c5 = a22*b20  a22*b21  a22*b22  a22*b23  a32*b20  a32*b21  a32*b22  a32*b23

    a0 = _mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(3, 3, 3, 3));        // a0 = a03, a03, a03, a03, a13, a13, a13, a13
    a1 = _mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(3, 3, 3, 3));        // a1 = a23, a23, a23, a23, a33, a33, a33, a33
    b1 = _mm256_permute2f128_ps(u1, u1, 0x11);                      // b1 = b30, b31, b32, b33, b30, b31, b32, b33
    c6 = _mm256_mul_ps(a0, b1);                                     // c6 = a03*b30  a03*b31  a03*b32  a03*b33  a13*b30  a13*b31  a13*b32  a13*b33
    c7 = _mm256_mul_ps(a1, b1);                                     // c7 = a23*b30  a23*b31  a23*b32  a23*b33  a33*b30  a33*b31  a33*b32  a33*b33

    c0 = _mm256_add_ps(c0, c2);                                     // c0 = c0 + c2 (two terms, first two rows)
    c4 = _mm256_add_ps(c4, c6);                                     // c4 = c4 + c6 (the other two terms, first two rows)
    c1 = _mm256_add_ps(c1, c3);                                     // c1 = c1 + c3 (two terms, second two rows)
    c5 = _mm256_add_ps(c5, c7);                                     // c5 = c5 + c7 (the other two terms, second two rows)

                                                                    // Finally complete addition of all four terms and return the results
    mResult.n[0] = _mm256_add_ps(c0, c4);       // n0 = a00*b00+a01*b10+a02*b20+a03*b30  a00*b01+a01*b11+a02*b21+a03*b31  a00*b02+a01*b12+a02*b22+a03*b32  a00*b03+a01*b13+a02*b23+a03*b33
                                                //      a10*b00+a11*b10+a12*b20+a13*b30  a10*b01+a11*b11+a12*b21+a13*b31  a10*b02+a11*b12+a12*b22+a13*b32  a10*b03+a11*b13+a12*b23+a13*b33
    mResult.n[1] = _mm256_add_ps(c1, c5);       // n1 = a20*b00+a21*b10+a22*b20+a23*b30  a20*b01+a21*b11+a22*b21+a23*b31  a20*b02+a21*b12+a22*b22+a23*b32  a20*b03+a21*b13+a22*b23+a23*b33
                                                //      a30*b00+a31*b10+a32*b20+a33*b30  a30*b01+a31*b11+a32*b21+a33*b31  a30*b02+a31*b12+a32*b22+a33*b32  a30*b03+a31*b13+a32*b23+a33*b33
    return mResult;
}
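As a possible refinement of my own (not part of this answer): on CPUs that also support FMA3 (Haswell and newer, a step beyond the Sandy Bridge baseline), each multiply/add pair can be fused with _mm256_fmadd_ps, which halves the number of arithmetic instructions. A sketch using the same MATRIX union and the same shuffle/permute scheme, compiled with -mfma (or /arch:AVX2):

#include <immintrin.h>

MATRIX myMultiplyFMA(MATRIX M1, MATRIX M2) {
    MATRIX r;
    __m256 t0 = M1.n[0], t1 = M1.n[1];          /* rows 0-1 and 2-3 of M1 */
    __m256 u0 = M2.n[0], u1 = M2.n[1];          /* rows 0-1 and 2-3 of M2 */

    __m256 b = _mm256_permute2f128_ps(u0, u0, 0x00);                /* row 0 of M2 in both lanes */
    __m256 acc0 = _mm256_mul_ps(_mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(0,0,0,0)), b);
    __m256 acc1 = _mm256_mul_ps(_mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(0,0,0,0)), b);

    b = _mm256_permute2f128_ps(u0, u0, 0x11);                       /* row 1 of M2 */
    acc0 = _mm256_fmadd_ps(_mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(1,1,1,1)), b, acc0);
    acc1 = _mm256_fmadd_ps(_mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(1,1,1,1)), b, acc1);

    b = _mm256_permute2f128_ps(u1, u1, 0x00);                       /* row 2 of M2 */
    acc0 = _mm256_fmadd_ps(_mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(2,2,2,2)), b, acc0);
    acc1 = _mm256_fmadd_ps(_mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(2,2,2,2)), b, acc1);

    b = _mm256_permute2f128_ps(u1, u1, 0x11);                       /* row 3 of M2 */
    acc0 = _mm256_fmadd_ps(_mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(3,3,3,3)), b, acc0);
    acc1 = _mm256_fmadd_ps(_mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(3,3,3,3)), b, acc1);

    r.n[0] = acc0;                              /* result rows 0-1 */
    r.n[1] = acc1;                              /* result rows 2-3 */
    return r;
}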

Answer by Miroslav Scaldov

Obviously you can fetch terms from four matrices at a time and multiply four matrices simultaneously using the same algorithm.

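A sketch of my own interpretation of this idea: store element (i,j) of four different matrices in the four lanes of one __m128 (a structure-of-arrays layout), then run the plain scalar algorithm once; every mulps/addps then advances four independent matrix products at the same time. The mat4x4x4 type and the function name below are mine, introduced only for illustration.

#include <xmmintrin.h>

/* e[i][j] holds element (i,j) of FOUR matrices, one per lane */
typedef struct { __m128 e[4][4]; } mat4x4x4;

void multiply_batch4(const mat4x4x4 *a, const mat4x4x4 *b, mat4x4x4 *c) {
    for (int i = 0; i < 4; ++i)
        for (int j = 0; j < 4; ++j) {
            /* the textbook formula, executed for four matrix pairs at once */
            __m128 acc = _mm_mul_ps(a->e[i][0], b->e[0][j]);
            acc = _mm_add_ps(acc, _mm_mul_ps(a->e[i][1], b->e[1][j]));
            acc = _mm_add_ps(acc, _mm_mul_ps(a->e[i][2], b->e[2][j]));
            acc = _mm_add_ps(acc, _mm_mul_ps(a->e[i][3], b->e[3][j]));
            c->e[i][j] = acc;
        }
}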