任意サイズ正方行列乗算の最適化(その1)

今まで3回にわたって、行列サイズがコンパイル時にわかる場合の正方行列乗算の最適化を取り扱ってきました(行列乗算の最適化入門 - よーる行列乗算の最適化入門(マルチコア編) - よーる行列乗算の最適化入門(GPGPU編) - よーる)。 行列サイズがコンパイル時にわかる場合は、コンパイラの自動ベクトル化だけで限界に近い性能が出ました。

今回はちゃんとした行列積プログラム、つまり行列サイズがコンパイル時にわからない場合の正方行列乗算の最適化をやっていきます。 配列のオーバーラップを気にしているのか、どうにもコンパイラの自動ベクトル化ではうまくいかなかったので、ちょっとだけアセンブリ言語に手を出しました。

5×4レジスタブロッキングカーネル

行列乗算の最適化入門(マルチコア編) - よーるで示したように、レジスタブロッキングは5×4が最適でした。 これについて、なぜそうなるかの説明は難しいですが、以下の総合的なバランスによるものと思われます。

  • fma命令のレイテンシ(4cycle)があるため、i軸アンローリングを4倍より小さくするとかなり遠くの命令までアウトオブオーダー実行しないとパイプラインを埋められない
  • i軸を4倍アンローリングしても、パイプラインを埋めるのにぎりぎり足りる量しか供給できないため、何らかの擾乱が発生すると演算器を使いきれない(?)
  • なので、i軸は5倍アンローリングするのが最適
  • k軸のアンローリング強度は、あまり大きくしすぎるとカーネルのレイテンシが長くなって、遠くの命令までアウトオブオーダー実行できなくなる
  • k軸のアンローリング強度は、小さすぎるとロード命令とストア命令が多く発生してしまう
  • なので(?)k軸は4倍アンローリングするくらいがよい

そこで、最適だった5×4のレジスタブロッキングを施したカーネルアセンブリ言語で書きました。 以下のCコードと等価な計算をするはずです。

void rbk( size_t N, double* a, double* b, double* c, size_t nj ) {
    for( size_t j = 0; j < nj; ++j )
    for( size_t i = 0; i < 5; ++i )
    for( size_t k = 0; k < 4; ++k )
    {
        c[i*N+j] = fma( a[i*N+k], b[k*N+j], c[i*N+j] );
    }
}

アセンブリ言語で書いたこの関数は以下の通りです。

        .intel_syntax noprefix
        .text
        .p2align 4
        .globl rbk54
        .type  rbk54, @function
rbk54:
        .cfi_startproc

        lea     rax, [rdi * 8]
        lea     r10, [rax + rax]
        lea     r11, [rax + r10]
        lea     rdi, [r10 + r10]

        cmp     r8 ,  8
        jb      .L546

        vmovupd zmm20, [rdx]
        vmovupd zmm21, [rdx + rax]
        vmovupd zmm22, [rdx + r10]
        vmovupd zmm23, [rdx + r11]

        vmovupd zmm25, [rcx]
        vmovupd zmm26, [rcx + rax]
        vmovupd zmm27, [rcx + r10]
        vmovupd zmm28, [rcx + r11]
        vmovupd zmm29, [rcx + rdi]

        vbroadcastsd    zmm0 , [rsi]
        vfmadd231pd     zmm25, zmm0 , zmm20
        vbroadcastsd    zmm4 , [rsi + rax]
        vfmadd231pd     zmm26, zmm4 , zmm20
        vbroadcastsd    zmm8 , [rsi + r10]
        vfmadd231pd     zmm27, zmm8 , zmm20
        vbroadcastsd    zmm12, [rsi + r11]
        vfmadd231pd     zmm28, zmm12, zmm20
        vbroadcastsd    zmm16, [rsi + rdi]
        vfmadd231pd     zmm29, zmm16, zmm20

        vbroadcastsd    zmm1 , [rsi + 8]
        vfmadd231pd     zmm25, zmm1 , zmm21
        vbroadcastsd    zmm5 , [rsi + rax + 8]
        vfmadd231pd     zmm26, zmm5 , zmm21
        vbroadcastsd    zmm9 , [rsi + r10 + 8]
        vfmadd231pd     zmm27, zmm9 , zmm21
        vbroadcastsd    zmm13, [rsi + r11 + 8]
        vfmadd231pd     zmm28, zmm13, zmm21
        vbroadcastsd    zmm17, [rsi + rdi + 8]
        vfmadd231pd     zmm29, zmm17, zmm21

        vbroadcastsd    zmm2 , [rsi + 16]
        vfmadd231pd     zmm25, zmm2 , zmm22
        vbroadcastsd    zmm6 , [rsi + rax + 16]
        vfmadd231pd     zmm26, zmm6 , zmm22
        vbroadcastsd    zmm10, [rsi + r10 + 16]
        vfmadd231pd     zmm27, zmm10, zmm22
        vbroadcastsd    zmm14, [rsi + r11 + 16]
        vfmadd231pd     zmm28, zmm14, zmm22
        vbroadcastsd    zmm18, [rsi + rdi + 16]
        vfmadd231pd     zmm29, zmm18, zmm22

        vbroadcastsd    zmm3 , [rsi + 24]
        vfmadd231pd     zmm25, zmm3 , zmm23
        vmovupd [rcx], zmm25

        vbroadcastsd    zmm7 , [rsi + rax + 24]
        vfmadd231pd     zmm26, zmm7 , zmm23
        vmovupd [rcx + rax], zmm26

        vbroadcastsd    zmm11, [rsi + r10 + 24]
        vfmadd231pd     zmm27, zmm11, zmm23
        vmovupd [rcx + r10], zmm27

        vbroadcastsd    zmm15, [rsi + r11 + 24]
        vfmadd231pd     zmm28, zmm15, zmm23
        vmovupd [rcx + r11], zmm28

        vbroadcastsd    zmm19, [rsi + rdi + 24]
        vfmadd231pd     zmm29, zmm19, zmm23
        vmovupd [rcx + rdi], zmm29

        add     rdx, 64
        add     rcx, 64
        sub     r8 ,  8

.L545:
        cmp     r8 ,  8
        jb      .L546

        vmovupd zmm20, [rdx]
        vmovupd zmm21, [rdx + rax]
        vmovupd zmm22, [rdx + r10]
        vmovupd zmm23, [rdx + r11]

        vmovupd zmm25, [rcx]
        vmovupd zmm26, [rcx + rax]
        vmovupd zmm27, [rcx + r10]
        vmovupd zmm28, [rcx + r11]
        vmovupd zmm29, [rcx + rdi]

        vfmadd231pd     zmm25, zmm0 , zmm20
        vfmadd231pd     zmm26, zmm4 , zmm20
        vfmadd231pd     zmm27, zmm8 , zmm20
        vfmadd231pd     zmm28, zmm12, zmm20
        vfmadd231pd     zmm29, zmm16, zmm20

        vfmadd231pd     zmm25, zmm1 , zmm21
        vfmadd231pd     zmm26, zmm5 , zmm21
        vfmadd231pd     zmm27, zmm9 , zmm21
        vfmadd231pd     zmm28, zmm13, zmm21
        vfmadd231pd     zmm29, zmm17, zmm21

        vfmadd231pd     zmm25, zmm2 , zmm22
        vfmadd231pd     zmm26, zmm6 , zmm22
        vfmadd231pd     zmm27, zmm10, zmm22
        vfmadd231pd     zmm28, zmm14, zmm22
        vfmadd231pd     zmm29, zmm18, zmm22

        vfmadd231pd     zmm25, zmm3 , zmm23
        vmovupd [rcx], zmm25

        vfmadd231pd     zmm26, zmm7 , zmm23
        vmovupd [rcx + rax], zmm26

        vfmadd231pd     zmm27, zmm11, zmm23
        vmovupd [rcx + r10], zmm27

        vfmadd231pd     zmm28, zmm15, zmm23
        vmovupd [rcx + r11], zmm28

        vfmadd231pd     zmm29, zmm19, zmm23
        vmovupd [rcx + rdi], zmm29

        add     rdx, 64
        add     rcx, 64
        sub     r8 ,  8

        jnz     .L545
        ret

.L546:
        mov     r9, 1
        shlx    r9, r9, r8
        sub     r9, 1
        kmovb   k1, r9d

        vmovupd zmm20{k1}, [rdx]
        vmovupd zmm21{k1}, [rdx + rax]
        vmovupd zmm22{k1}, [rdx + r10]
        vmovupd zmm23{k1}, [rdx + r11]

        vmovupd zmm25{k1}, [rcx]
        vmovupd zmm26{k1}, [rcx + rax]
        vmovupd zmm27{k1}, [rcx + r10]
        vmovupd zmm28{k1}, [rcx + r11]
        vmovupd zmm29{k1}, [rcx + rdi]

        vfmadd231pd     zmm25, zmm0 , zmm20
        vfmadd231pd     zmm26, zmm4 , zmm20
        vfmadd231pd     zmm27, zmm8 , zmm20
        vfmadd231pd     zmm28, zmm12, zmm20
        vfmadd231pd     zmm29, zmm16, zmm20

        vfmadd231pd     zmm25, zmm1 , zmm21
        vfmadd231pd     zmm26, zmm5 , zmm21
        vfmadd231pd     zmm27, zmm9 , zmm21
        vfmadd231pd     zmm28, zmm13, zmm21
        vfmadd231pd     zmm29, zmm17, zmm21

        vfmadd231pd     zmm25, zmm2 , zmm22
        vfmadd231pd     zmm26, zmm6 , zmm22
        vfmadd231pd     zmm27, zmm10, zmm22
        vfmadd231pd     zmm28, zmm14, zmm22
        vfmadd231pd     zmm29, zmm18, zmm22

        vfmadd231pd     zmm25, zmm23, [rsi + 24]{1to8}
        vmovupd [rcx]{k1}, zmm25

        vfmadd231pd     zmm26, zmm23, [rsi + rax + 24]{1to8}
        vmovupd [rcx + rax]{k1}, zmm26

        vfmadd231pd     zmm27, zmm23, [rsi + r10 + 24]{1to8}
        vmovupd [rcx + r10]{k1}, zmm27

        vfmadd231pd     zmm28, zmm23, [rsi + r11 + 24]{1to8}
        vmovupd [rcx + r11]{k1}, zmm28

        vfmadd231pd     zmm29, zmm23, [rsi + rdi + 24]{1to8}
        vmovupd [rcx + rdi]{k1}, zmm29

        ret
        .cfi_endproc
        .size rbk54, .-rbk54

        .section        .note.GNU-stack,"",@progbits

アセンブリ言語実装の解説

概略

この実装は、以下の特徴を持ちます。

  1. 呼び出し先保存レジスタを使わないで動作する(汎用レジスタは以下しか使っていない)
    • 引数(rdirsirdxrcxr8
    • その他の呼び出し元保存レジスタr9r10r11rax
  2. njがいくつであっても動作する(特に、8の倍数を仮定しない)
  3. vbroadcastsd命令を固めないで、vfmadd命令を間に挟んでいる
  4. j-k-iの順番でループをネストしている

まず、1.を守ることで、余分な命令を削減しています。 コンパイラが出力するコードは、これを守っていないため、全体的に命令数が多くなりがちです。 行列乗算コードは、基本的には演算ボトルネックですが、カーネルの命令数が増えるとアウトオブオーダー実行の恩恵(擾乱耐性)が減ってしまうため、なるべく命令数を減らしたほうが良いです。

次に、2.が、外側のコードを書くときに重要な特性です。 コンパイラの出力するコードではここがあまりうまくいかない傾向にあります。 マスクは汎用レジスタ上で簡単に作れるので、kmov命令を使ってマスクレジスタに移動して使います。

3.は、Ryzen 9 7950Xにおける性能を高める上で、最後の数パーセント向上に寄与します。 実際、これの導入で1120 GFLOPSくらいだった性能が1150 GFLOPSくらいの性能まで向上しました(3%弱)。 これの理由はよくわかっていませんが、おそらく発行ポート割り当ての問題だと思っています。 vfmadd命令はFP0ポートかFP1ポートを使用し、vbroadcastsd命令はFP1ポートかFP2ポートを使用します。 よって、性能を最大化するためには、vbroadcastsd命令はFP2ポートに行ってほしいです。 vbroadcastsd命令を固めて配置するよりもvfmadd命令で間に挟んだ方が、vbroadcastsd命令がFP1ポートに発行されてしまうことを防止できる感じがあります。

4.は、理由はよくわかりませんが、この方が高速に動作したのでそうしています。 以下の三つの理由が考えられますが、本当の理由が何なのか、あるいはこれら以外の何か別の理由なのかはよくわかりません。

  • 早く終わらせるべき命令を早く発行できる、という意味でスケジューリングアルゴリズムに都合がいい順である
  • スケジューラに滞留する時間が最小化される、という意味でスケジューリングアルゴリズムに都合がいい順である
  • 完了した命令がリタイアしやすい、という意味で遠くの命令までアウトオブオーダー実行しやすい順である

以下、コードを細かく見ていきます。

端数判定

まず、njが8未満かを確認しています。 njが8未満の場合は、マスク付き実行しないといけないので、端数処理用コードにジャンプします。

1周目の基本ブロック

Nだけずれたアドレスを算出するために必要な値を計算した後、vbroadcastsdvfmaddを交互に繰り返しているだけです。 一周終わったら、njを8減らします。

2周目以降の基本ブロック

これも簡単です。 まずはnjが8未満かを確認し、その場合は端数処理用コードにジャンプします。 後はvfmaddをたくさんやるだけです。 njを8減らして0になっていればreturn、そうでなければこの基本ブロックを繰り返します。

端数処理用コード

マスク付きで実行すること以外はほとんど同じです。 関数先頭で端数判定された場合はvbroadcastが終わっていないので、メモリオペランドのブロードキャストを使います。 関数先頭以外で端数判定された場合はvbroadcastされているのでそれを転用したほうが効率が良いはずですが、気にしないことにします。

5×4レジスタブロッキングできなかった場合のカーネル

都合のいいNが来るとは限らないので、5×4のレジスタブロッキングができるとは限りません。 そこで、ninkの値に応じて可変サイズのレジスタブロッキングができるカーネルを用意しました。 x86の引数レジスタは6個しかないので、ninkの情報をnikにまとめて渡します。 短い命令で比較できるようにするため、nk-1 << 4 | 16 >> niという変な渡し方をしています。

// rbkxx( N, &a[i*N+k], &b[k*N+j], &c[i*N+j], nj, nk-1 << 4 | 16 >> ni ); のように呼び出してね
void rbkxx( size_t N, double* a, double* b, double* c, size_t nj, size_t nki ) {
    size_t nk = /* ... */;
    size_t ni = /* ... */;
    for( size_t j = 0; j < nj; ++j )
    for( size_t i = 0; i < ni; ++i )
    for( size_t k = 0; k < nk; ++k )
    {
        c[i*N+j] = fma( a[i*N+k], b[k*N+j], c[i*N+j] );
    }
}

アセンブリコードは、先のコードに分岐をたくさん入れただけです。 ちまちま分岐しないといけないので長いです。

        .intel_syntax noprefix
        .text
        .p2align 4
        .globl rbkxx
        .type  rbkxx, @function
rbkxx:
        .cfi_startproc

        lea     rax, [rdi * 8]
        lea     r10, [rax + rax]
        lea     r11, [rax + r10]
        lea     rdi, [r10 + r10]

        cmp     r8 ,  8
        jb      .L546

        vmovupd zmm20, [rdx]
        cmp     r9 , 16
        jb      .L54a1
        vmovupd zmm21, [rdx + rax]
        cmp     r9 , 32
        jb      .L54a1
        vmovupd zmm22, [rdx + r10]
        cmp     r9 , 48
        jb      .L54a1
        vmovupd zmm23, [rdx + r11]

.L54a1:
        vmovupd zmm25, [rcx]
        test    r9 , 8
        jnz     .L54b1
        vmovupd zmm26, [rcx + rax]
        test    r9 , 4
        jnz     .L54b1
        vmovupd zmm27, [rcx + r10]
        test    r9 , 2
        jnz     .L54b1
        vmovupd zmm28, [rcx + r11]
        test    r9 , 1
        jnz     .L54b1
        vmovupd zmm29, [rcx + rdi]

.L54b1:
        vbroadcastsd    zmm0 , [rsi]
        vfmadd231pd     zmm25, zmm0 , zmm20
        cmp     r9 , 16
        jb      .L54a2
        vbroadcastsd    zmm1 , [rsi + 8]
        vfmadd231pd     zmm25, zmm1 , zmm21
        cmp     r9 , 32
        jb      .L54a2
        vbroadcastsd    zmm2 , [rsi + 16]
        vfmadd231pd     zmm25, zmm2 , zmm22
        cmp     r9 , 48
        jb      .L54a2
        vbroadcastsd    zmm3 , [rsi + 24]
        vfmadd231pd     zmm25, zmm3 , zmm23
.L54a2:
        vmovupd [rcx], zmm25
        test    r9 , 8
        jnz     .L54b2

        vbroadcastsd    zmm4 , [rsi + rax]
        vfmadd231pd     zmm26, zmm4 , zmm20
        cmp     r9 , 16
        jb      .L54a3
        vbroadcastsd    zmm5 , [rsi + rax + 8]
        vfmadd231pd     zmm26, zmm5 , zmm21
        cmp     r9 , 32
        jb      .L54a3
        vbroadcastsd    zmm6 , [rsi + rax + 16]
        vfmadd231pd     zmm26, zmm6 , zmm22
        cmp     r9 , 48
        jb      .L54a3
        vbroadcastsd    zmm7 , [rsi + rax + 24]
        vfmadd231pd     zmm26, zmm7 , zmm23
.L54a3:
        vmovupd [rcx + rax], zmm26
        test    r9 , 4
        jnz     .L54b2

        vbroadcastsd    zmm8 , [rsi + r10]
        vfmadd231pd     zmm27, zmm8 , zmm20
        cmp     r9 , 16
        jb      .L54a4
        vbroadcastsd    zmm9 , [rsi + r10 + 8]
        vfmadd231pd     zmm27, zmm9 , zmm21
        cmp     r9 , 32
        jb      .L54a4
        vbroadcastsd    zmm10, [rsi + r10 + 16]
        vfmadd231pd     zmm27, zmm10, zmm22
        cmp     r9 , 48
        jb      .L54a4
        vbroadcastsd    zmm11, [rsi + r10 + 24]
        vfmadd231pd     zmm27, zmm11, zmm23
.L54a4:
        vmovupd [rcx + r10], zmm27
        test    r9 , 2
        jnz     .L54b2

        vbroadcastsd    zmm12, [rsi + r11]
        vfmadd231pd     zmm28, zmm12, zmm20
        cmp     r9 , 16
        jb      .L54a5
        vbroadcastsd    zmm13, [rsi + r11 + 8]
        vfmadd231pd     zmm28, zmm13, zmm21
        cmp     r9 , 32
        jb      .L54a5
        vbroadcastsd    zmm14, [rsi + r11 + 16]
        vfmadd231pd     zmm28, zmm14, zmm22
        cmp     r9 , 48
        jb      .L54a5
        vbroadcastsd    zmm15, [rsi + r11 + 24]
        vfmadd231pd     zmm28, zmm15, zmm23
.L54a5:
        vmovupd [rcx + r11], zmm28
        test    r9 , 1
        jnz     .L54b2

        vbroadcastsd    zmm16, [rsi + rdi]
        vfmadd231pd     zmm29, zmm16, zmm20
        cmp     r9 , 16
        jb      .L54a6
        vbroadcastsd    zmm17, [rsi + rdi + 8]
        vfmadd231pd     zmm29, zmm17, zmm21
        cmp     r9 , 32
        jb      .L54a6
        vbroadcastsd    zmm18, [rsi + rdi + 16]
        vfmadd231pd     zmm29, zmm18, zmm22
        cmp     r9 , 48
        jb      .L54a6
        vbroadcastsd    zmm19, [rsi + rdi + 24]
        vfmadd231pd     zmm29, zmm19, zmm23
.L54a6:
        vmovupd [rcx + rdi], zmm29

.L54b2:
        add     rdx, 64
        add     rcx, 64
        sub     r8 ,  8

.L545:
        cmp     r8 ,  8
        jb      .L546

        vmovupd zmm20, [rdx]
        cmp     r9 , 16
        jb      .L54a7
        vmovupd zmm21, [rdx + rax]
        cmp     r9 , 32
        jb      .L54a7
        vmovupd zmm22, [rdx + r10]
        cmp     r9 , 48
        jb      .L54a7
        vmovupd zmm23, [rdx + r11]
.L54a7:

        vmovupd zmm25, [rcx]
        test    r9 , 8
        jnz     .L54b3
        vmovupd zmm26, [rcx + rax]
        test    r9 , 4
        jnz     .L54b3
        vmovupd zmm27, [rcx + r10]
        test    r9 , 2
        jnz     .L54b3
        vmovupd zmm28, [rcx + r11]
        test    r9 , 1
        jnz     .L54b3
        vmovupd zmm29, [rcx + rdi]
.L54b3:

        vfmadd231pd     zmm25, zmm0 , zmm20
        cmp     r9 , 16
        jb      .L54a8
        vfmadd231pd     zmm25, zmm1 , zmm21
        cmp     r9 , 32
        jb      .L54a8
        vfmadd231pd     zmm25, zmm2 , zmm22
        cmp     r9 , 48
        jb      .L54a8
        vfmadd231pd     zmm25, zmm3 , zmm23
.L54a8:
        vmovupd [rcx], zmm25
        test    r9 , 8
        jnz     .L54b4

        vfmadd231pd     zmm26, zmm4 , zmm20
        cmp     r9 , 16
        jb      .L54a9
        vfmadd231pd     zmm26, zmm5 , zmm21
        cmp     r9 , 32
        jb      .L54a9
        vfmadd231pd     zmm26, zmm6 , zmm22
        cmp     r9 , 48
        jb      .L54a9
        vfmadd231pd     zmm26, zmm7 , zmm23
.L54a9:
        vmovupd [rcx + rax], zmm26
        test    r9 , 4
        jnz     .L54b4

        vfmadd231pd     zmm27, zmm8 , zmm20
        cmp     r9 , 16
        jb      .L54a10
        vfmadd231pd     zmm27, zmm9 , zmm21
        cmp     r9 , 32
        jb      .L54a10
        vfmadd231pd     zmm27, zmm10, zmm22
        cmp     r9 , 48
        jb      .L54a10
        vfmadd231pd     zmm27, zmm11, zmm23
.L54a10:
        vmovupd [rcx + r10], zmm27
        test    r9 , 2
        jnz     .L54b4

        vfmadd231pd     zmm28, zmm12, zmm20
        cmp     r9 , 16
        jb      .L54a11
        vfmadd231pd     zmm28, zmm13, zmm21
        cmp     r9 , 32
        jb      .L54a11
        vfmadd231pd     zmm28, zmm14, zmm22
        cmp     r9 , 48
        jb      .L54a11
        vfmadd231pd     zmm28, zmm15, zmm23
.L54a11:
        vmovupd [rcx + r11], zmm28
        test    r9 , 1
        jnz     .L54b4

        vfmadd231pd     zmm29, zmm16, zmm20
        cmp     r9 , 16
        jb      .L54a12
        vfmadd231pd     zmm29, zmm17, zmm21
        cmp     r9 , 32
        jb      .L54a12
        vfmadd231pd     zmm29, zmm18, zmm22
        cmp     r9 , 48
        jb      .L54a12
        vfmadd231pd     zmm29, zmm19, zmm23
.L54a12:
        vmovupd [rcx + rdi], zmm29

.L54b4:
        add     rdx, 64
        add     rcx, 64
        sub     r8 ,  8

        jnz     .L545
        ret

.L546:
        mov     edi, 1
        shlx    edi, edi, r8d
        sub     edi, 1
        kmovb   k1, edi
        lea     rdi, [r10 + r10]

        vmovupd zmm20{k1}, [rdx]
        cmp     r9 , 16
        jb      .L54a13
        vmovupd zmm21{k1}, [rdx + rax]
        cmp     r9 , 32
        jb      .L54a13
        vmovupd zmm22{k1}, [rdx + r10]
        cmp     r9 , 48
        jb      .L54a13
        vmovupd zmm23{k1}, [rdx + r11]
.L54a13:

        vmovupd zmm25{k1}, [rcx]
        test    r9 , 8
        jnz     .L54b5
        vmovupd zmm26{k1}, [rcx + rax]
        test    r9 , 4
        jnz     .L54b5
        vmovupd zmm27{k1}, [rcx + r10]
        test    r9 , 2
        jnz     .L54b5
        vmovupd zmm28{k1}, [rcx + r11]
        test    r9 , 1
        jnz     .L54b5
        vmovupd zmm29{k1}, [rcx + rdi]
.L54b5:
        vfmadd231pd     zmm25, zmm20, [rsi]{1to8}
        cmp     r9 , 16
        jb      .L54a14
        vfmadd231pd     zmm25, zmm21, [rsi + 8]{1to8}
        cmp     r9 , 32
        jb      .L54a14
        vfmadd231pd     zmm25, zmm22, [rsi + 16]{1to8}
        cmp     r9 , 48
        jb      .L54a14
        vfmadd231pd     zmm25, zmm23, [rsi + 24]{1to8}
.L54a14:
        vmovupd [rcx]{k1}, zmm25
        test    r9 , 8
        jnz     .L54b6

        vfmadd231pd     zmm26, zmm20, [rsi + rax]{1to8}
        cmp     r9 , 16
        jb      .L54a15
        vfmadd231pd     zmm26, zmm21, [rsi + rax + 8]{1to8}
        cmp     r9 , 32
        jb      .L54a15
        vfmadd231pd     zmm26, zmm22, [rsi + rax + 16]{1to8}
        cmp     r9 , 48
        jb      .L54a15
        vfmadd231pd     zmm26, zmm23, [rsi + rax + 24]{1to8}
.L54a15:
        vmovupd [rcx + rax]{k1}, zmm26
        test    r9 , 4
        jnz     .L54b6

        vfmadd231pd     zmm27, zmm20, [rsi + r10]{1to8}
        cmp     r9 , 16
        jb      .L54a16
        vfmadd231pd     zmm27, zmm21, [rsi + r10 + 8]{1to8}
        cmp     r9 , 32
        jb      .L54a16
        vfmadd231pd     zmm27, zmm22, [rsi + r10 + 16]{1to8}
        cmp     r9 , 48
        jb      .L54a16
        vfmadd231pd     zmm27, zmm23, [rsi + r10 + 24]{1to8}
.L54a16:
        vmovupd [rcx + r10]{k1}, zmm27
        test    r9 , 2
        jnz     .L54b6

        vfmadd231pd     zmm28, zmm20, [rsi + r11]{1to8}
        cmp     r9 , 16
        jb      .L54a17
        vfmadd231pd     zmm28, zmm21, [rsi + r11 + 8]{1to8}
        cmp     r9 , 32
        jb      .L54a17
        vfmadd231pd     zmm28, zmm22, [rsi + r11 + 16]{1to8}
        cmp     r9 , 48
        jb      .L54a17
        vfmadd231pd     zmm28, zmm23, [rsi + r11 + 24]{1to8}
.L54a17:
        vmovupd [rcx + r11]{k1}, zmm28
        test    r9 , 1
        jnz     .L54b6

        vfmadd231pd     zmm29, zmm20, [rsi + rdi]{1to8}
        cmp     r9 , 16
        jb      .L54a18
        vfmadd231pd     zmm29, zmm21, [rsi + rdi + 8]{1to8}
        cmp     r9 , 32
        jb      .L54a18
        vfmadd231pd     zmm29, zmm22, [rsi + rdi + 16]{1to8}
        cmp     r9 , 48
        jb      .L54a18
        vfmadd231pd     zmm29, zmm23, [rsi + rdi + 24]{1to8}
.L54a18:
        vmovupd [rcx + rdi]{k1}, zmm29
.L54b6:

        ret
        .cfi_endproc
        .size rbkxx, .-rbkxx

        .section        .note.GNU-stack,"",@progbits

外側のコード

外側のコードまで全部アセンブリ言語で作るのは大変なので、C++コンパイラに任せることにします。 行列乗算の最適化入門(マルチコア編) - よーるで示したように、L1Dキャッシュブロッキングは、50(i2)×40(k2)×40(j2)が最適でした。 その外側のキャッシュブロッキングは、50(i1)×80(k1)×80(j1)が最適と示していましたが、これは4000の約数でないといけない制約の下での最適値であり、実際には50(i)×120(k)×120(j)が良いようです。 基本的には、このサイズでの分割を行っていきます。 分割サイズがわかっている場合はテンプレート引数に入れることで最適化を支援します。

内側に近いループではテンプレートを明示的に特殊化して最適化を支援します。 6段目(j2)と8段目(i3)は分割サイズを調整する(極端に小さい分割サイズをできるだけ回避する)小細工を入れていますが、これによる性能向上はあまりなさそうです。 7段目(k3)は実行時引数をテンプレート引数に戻しているだけです。 これらの明示的特殊化はあまり意味がなさそう(特に端数が出ないN = 2400等では全く効果がないはず)なのですが、なくすと性能が下がったので残してあります。 コンパイラは9重もの再帰呼び出しをうまく取り扱えないのかもしれません。

static constexpr std::size_t cutn( std::size_t Q ) {
    switch( Q ) {
    case  1: return 50;  // L3C_I...i1
    case  2: return 120; // L3C_K...k1
    case  3: return 120; // L3C_J...j1
    case  4: return 50;  // L1D_I...i2
    case  5: return 40;  // L1D_K...k2
    case  6: return 40;  // L1D_J...j2
    case  7: return 5;   // Reg_I...i3
    case  8: return 4;   // Reg_K...k3
    default: return 0;
    }
}

extern "C" {
    void rbk54( std::size_t, double*, double*, double*, std::size_t );
    void rbkxx( std::size_t, double*, double*, double*, std::size_t, std::size_t );
}

template<std::size_t Q, std::size_t NI_, std::size_t NK_, std::size_t NJ_>
struct mmi {
    static void cut( std::size_t N, std::size_t ni, std::size_t nk, std::size_t nj, int i, int k, int j, double* a, double* b, double* c ) {
        const std::size_t NI = NI_ == 0 ? ni : NI_;
        static constexpr std::size_t CUT_I = cutn( Q );
        for( std::size_t i1 = 0; i1 < NI; i1 += CUT_I ) {
            if( i1 + CUT_I <= NI ) {
                mmi<Q+1, NK_, NJ_, CUT_I>::cut( N, nk, nj, CUT_I  , k, j, i + i1, a, b, c );
            } else {
                mmi<Q+1, NK_, NJ_, 0    >::cut( N, nk, nj, NI - i1, k, j, i + i1, a, b, c );
            }
        }
    }
};


template<std::size_t NJ_, std::size_t NI_, std::size_t NK_>
struct mmi<9, NJ_, NI_, NK_> {
    static void cut( std::size_t N, std::size_t nj, std::size_t ni, std::size_t nk, int j0, int i0, int k0, double* a, double* b, double* c ) {
        const std::size_t NI = NI_ == 0 ? ni : NI_;
        const std::size_t NK = NK_ == 0 ? nk : NK_;
        if( NI == 5 && NK == 4 ) {
            rbk54( N, &a[i0*N+k0], &b[k0*N+j0], &c[i0*N+j0], nj );
        } else {
            rbkxx( N, &a[i0*N+k0], &b[k0*N+j0], &c[i0*N+j0], nj, NK-1 << 4 | 16 >> NI );
        }
    }
};

template<std::size_t NJ_, std::size_t NI_>
struct mmi<8, 0, NJ_, NI_> {
    static void cut( std::size_t N, std::size_t nk, std::size_t nj, std::size_t ni, int k0, int j0, int i0, double* a, double* b, double* c ) {
top:;
        switch( nk ) {
        case  1: mmi<9, NJ_, NI_, 1>::cut( N, nj, ni,  1, j0, i0, k0, a, b, c ); break;

        case  5: mmi<9, NJ_, NI_, 3>::cut( N, nj, ni,  3, j0, i0, k0, a, b, c ); k0 += 3;
        case  2: mmi<9, NJ_, NI_, 2>::cut( N, nj, ni,  2, j0, i0, k0, a, b, c ); break;

        case  9: mmi<9, NJ_, NI_, 3>::cut( N, nj, ni,  3, j0, i0, k0, a, b, c ); k0 += 3;
        case  6: mmi<9, NJ_, NI_, 3>::cut( N, nj, ni,  3, j0, i0, k0, a, b, c ); k0 += 3;
        case  3: mmi<9, NJ_, NI_, 3>::cut( N, nj, ni,  3, j0, i0, k0, a, b, c ); break;

        case  4: mmi<9, NJ_, NI_, 4>::cut( N, nj, ni,  4, j0, i0, k0, a, b, c ); break;

        default: mmi<9, NJ_, NI_, 4>::cut( N, nj, ni,  4, j0, i0, k0, a, b, c ); nk -= 4; k0 += 4; goto top;
        }
    }
};

template<std::size_t NK_, std::size_t NJ_>
struct mmi<7, 0, NK_, NJ_> {
    static void cut( std::size_t N, std::size_t ni, std::size_t nk, std::size_t nj, int i0, int k0, int j0, double* a, double* b, double* c ) {
top:;
        switch( ni ) {
        case  1: mmi<8, NK_, NJ_, 1>::cut( N, nk, nj,  1, k0, j0, i0, a, b, c ); break;
        case  2: mmi<8, NK_, NJ_, 2>::cut( N, nk, nj,  2, k0, j0, i0, a, b, c ); break;
        case  3: mmi<8, NK_, NJ_, 3>::cut( N, nk, nj,  3, k0, j0, i0, a, b, c ); break;
        case  4: mmi<8, NK_, NJ_, 4>::cut( N, nk, nj,  4, k0, j0, i0, a, b, c ); break;
        case  5: mmi<8, NK_, NJ_, 5>::cut( N, nk, nj,  5, k0, j0, i0, a, b, c ); break;

        default: mmi<8, NK_, NJ_, 5>::cut( N, nk, nj,  5, k0, j0, i0, a, b, c ); ni -= 5; i0 += 5; goto top;
        }
    }
};

template<std::size_t NJ_, std::size_t NI_, std::size_t NK_>
struct mmi<6, NJ_, NI_, NK_> {
    static void cut( std::size_t N, std::size_t nj, std::size_t ni, std::size_t nk, int j0, int i0, int k0, double* a, double* b, double* c ) {
        const std::size_t NJ = NJ_ == 0 ? nj : NJ_;
        if( NJ == 120 ) {
            mmi<7, NI_, NK_, 40>::cut( N, ni, nk, 40     , i0, k0, j0     , a, b, c );
            mmi<7, NI_, NK_, 40>::cut( N, ni, nk, 40     , i0, k0, j0 + 40, a, b, c );
            mmi<7, NI_, NK_, 40>::cut( N, ni, nk, 40     , i0, k0, j0 + 80, a, b, c );
        } else if( NJ >= 80 ) {
            mmi<7, NI_, NK_, 40>::cut( N, ni, nk, 40     , i0, k0, j0     , a, b, c );
            mmi<7, NI_, NK_, 0 >::cut( N, ni, nk, NJ - 40, i0, k0, j0 + 40, a, b, c );
        } else {
            mmi<7, NI_, NK_, 0 >::cut( N, ni, nk, NJ     , i0, k0, j0     , a, b, c );
        }
    }
};

void mm( std::size_t tid, std::size_t N, double* a, double* b, double* c ) {
    assert( tid < 16 );
    const std::size_t TH_I = (N + 15) / 16;
    const std::size_t i_begin = TH_I * tid;
    const std::size_t i_end   = TH_I * (tid + 1);
    if( i_begin >= N ) { return; }
    if( i_end <= N ) {
        mmi<1, 0, 0, 0>::cut( N, TH_I, N, N, i_begin, 0, 0, a, b, c );
    } else {
        mmi<1, 0, 0, 0>::cut( N, N % TH_I, N, N, i_begin, 0, 0, a, b, c );
    }
}

void mm( std::size_t N, double* a, double* b, double* c ) {
#pragma omp parallel for
    for( int tid = 0; tid < 16; ++tid ) {
        mm( tid, N, a, b, c );
    }
}

速度測定

以下の環境で速度を測ってみました。 また、OpenBLASのDGEMM関数の速度と比較してみました。

  • Ryzen 9 7950X(4.9 GHzくらい?)
  • DDR5-4800(38.4 GB/s)
  • g++ 13.2.0

なお、OpenBLASはこの環境で自分でビルドしました。

実験方法

以下のコードを10回実行してもっともよかった結果を採用しました。 以下のコードはmmを5回実行するので、都合50回のうち最も良かった結果ということになります。 なお、5回実行のうち初回の実行はキャッシュヒット率が悪いため、良い性能となることはほぼありません。

// ==== benchmark ====
void test( std::size_t N, int mode, double* a, double* b, double* c ) {
    for( int i = 0; i < N*N; ++i ) {
       c[i] = 0.0;
    }

    std::uint64_t hash = 0;
    const auto start = std::chrono::system_clock::now();
    if( mode ) {
        mm( N, a, b, c );
    } else {
        for( int i = 0; i < N; ++i )
        for( int k = 0; k < N; ++k )
        for( int j = 0; j < N; ++j )
        c[i*N+j] = std::fma( a[i*N+k], b[k*N+j], c[i*N+j] );
    }
    const auto finish = std::chrono::system_clock::now();
    for( int i = 0; i < N; ++i ) {
        for( int j = 0; j < N; ++j ) {
            hash = hash << 1 ^ hash >> 63;
            hash ^= std::bit_cast<std::uint64_t>( c[i*N+j] );
        }
    }
    const double flop_per_fma = 2.0;
    const double insn_per_fma = 1.0 / 8;
    const double s = std::chrono::duration_cast<std::chrono::nanoseconds>( finish - start ).count() * 1e-9;
    std::printf( "[%04ld] %.8f sec [ %*.3f GX/s], %*.3f GFLOPS, %.3e FPIns/s [hash=%016lx]\n", N, s, 7, N*N*N/s*1e-9, 8, N*N*N*flop_per_fma/s*1e-9, N*N*N*insn_per_fma/s, hash );
}

// ==== main ====

int main( int argc, char* argv[] ) {
    std::ifstream ifs( argv[1], std::ios::binary );
    const std::size_t N = std::atoi( argv[2] );

    double* a  = new (std::align_val_t{64}) double [N*N];
    double* b  = new (std::align_val_t{64}) double [N*N];
    double* c  = new (std::align_val_t{64}) double [N*N];
    for( int i = 0; i < N*N; ++i ) {
        ifs.read( (char*)&a[i], 8 );
        ifs.read( (char*)&b[i], 8 );
    }

    // test( N, 0, a, b, c );
    for( int i = 0; i < 5; ++i ) {
        test( N, 1, a, b, c );
    }
}

なお、test( N, 0, a, b, c );コメントアウトを外すことで、逐次実行の場合と結果が変わっていないことを確認できます。

また、入力は以下のプログラムで作りました。 [-1, 1)に一様分布する倍精度浮動小数点数の乱数を生成することを意図しています(std::uniform_real_distributionを使わないのは、結果の移植性を考慮しているためです*1)。

#include <cstdint>
#include <fstream>
#include <random>
#include <bit>

int main() {
        std::ofstream ofs( "d.bin", std::ios::binary );
        std::mt19937_64 mt;
        for( int i = 0; i < 10000*10000; ++i ) {
                __uint128_t x = mt();
                x <<= 64;
                x |= mt();
                __int128_t y = std::bit_cast<__int128_t>( x );
                double d = y * 0x1.p-127;
                ofs.write( (char*)&d, 8 );
        }
}

また、OpenBLASのDGEMMは以下のように呼び出しました。

cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, N, N, N, 1.0, a, N, b, N, 1.0, c, N);

実のところ、最初は間違ってβ=0.0で測定していました。 せっかくなので、以下ではβ=0.0の結果も掲載します。

測定結果

図1は、行列サイズNごとの、計算にかかった時間から算出したGFLOPS値です。

図1: 行列サイズNごとのGFLOPS値(Ryzen 9 7950X)

この結果から、いくつも面白いことがわかります。

まず、私のコードは行列サイズが256、512、1024、1536、2048、などの大きな二冪を約数に含む数の周囲で性能が大きく低下しています。 これはよく知られているように、セットアソシアティブキャッシュのインデクスが衝突して競合性ミスが大量発生することが原因です。 一方で、OpenBLASの実装はそのようなことが起きていません。バッファを用意しているのが理由と思われます(参考:密行列積の高速化 OpenBLAS編 #数値計算 - Qiita)。 また、私のコードは行列サイズが512n+173、512n+259、512n+343などの部分でも大きな性能低下がみられます。 これらは512(n+1/3)、512(n+1/2)、512(n+2/3)に近く、同様に競合性ミスが増えると予測される行列サイズですが、微妙にずれている原因はよくわかりませんでした(私のコードが端数に弱い影響でしょうか)。

次に、OpenBLASの実装は、行列サイズが100の所で性能の傾向が不連続になっています。 これは、行列サイズが100以下の場合はマルチスレッドではなくシングルスレッドで行うためであると思われます。 実際、行列サイズが100以下の場合の性能は、86 GFLOPS程度と、シングルスレッドの理論性能の限界に近くなっています。 行列サイズが小さいときにマルチスレッドを使わないのは、スレッド立ち上げオーバーヘッドの影響を考慮していると思われます。 なお、私の実装では、行列サイズ50程度の所でシングルスレッドとマルチスレッドの優劣が入れ替わりました(図2)。

また、私のコードは行列サイズが小さいときの性能の立ち上がりがOpenBLASより急激になっています。 これの理由はよくわかりませんでしたが、キャッシュブロッキングの性能が最も高くなるような分割を(手動で)探し当てたことによるかもしれません。 OpenBLASはL1キャッシュとL2キャッシュのサイズ・ラインサイズを#define文で設定していますが、その他はSkylake-Xのコードを流用しているようです(?)。 マルチスレッドの時に性能に大きく効いてくるのは、全てのコアで共有されているL3キャッシュのスループットとレイテンシだったりするので、それが原因で性能の立ち上がりが悪いのかもしれません。

最後に、β=0.0はβ=1.0の時と比べて有意に性能が低いことがわかります。 OpenBLASでは、β≠1.0の時、行列Cをβ倍する前処理が行われるようです(参考:密行列積の高速化 OpenBLAS編 #数値計算 - Qiita)。 これは行列積の 2N^3FLOPに対して N^2FLOPのコストしかかからない処理ですが、しかしながら無視できない重たい処理です。 なぜなら、この処理のB/F比はなんと8もあり、一瞬でメモリバンド幅上限に達してしまうからです。 L3キャッシュに乗れば大丈夫かというと、実際のところはメモリバンド幅どころかL3キャッシュバンド幅上限も容易に達してしまいます。 Ryzen 9 7950XのL3キャッシュのバンド幅は32 Byte/cycleしかないので、計算能力256FLOP/cycleに対してB/F比が0.125(要求の1/64)しかありません。 L3キャッシュに乗らない場合はさらに悲惨で、38.4 GB/s≒8 Byte/cycleしかないので、B/F比が0.03(要求の1/260)しかありません。 なので、β倍するコードは実質的な計算コストが 128N^2FLOP分(L3キャッシュに乗らない場合は 520N^2FLOP分)もかかる、無視できない処理となります。

図2: 行列サイズが小さいところの拡大図。シングルスレッド実行の場合も追加した
ングルスレッド実行の場合も追加した

図2を見ると、周期8の明確なパターンが存在することがわかります。 レジスタブロッキングは5の倍数や4の倍数からのずれが重要そうなのに、そのようなパターンがほとんど見られないのは不思議です。 OpenBLASの性能にも同様のパターンがみられることを考えると、vmovupdの特性(アラインがあっていなくても動作するが、あっていないと遅い)を見ているのかもしれません。

図3: Nが8n-7~8n中でGFLOPS値が最大のものと最小のものをプロットした

図3は、Nが8n-7~8nのうち、GFLOPS値が最も低かったものと高かったものを選び出してプロットしたものです。 最も高くなるのはほとんどの場合、N=8nかN=8n-4の地点でした。 一方、最も低くなるのはNが十分大きければほとんどの場合、N=8n-5かN=8n-3かN=8n-1の地点でした。 端数処理はもうちょっと真面目にやらないと、Nが4の倍数以外の場合の時にOpenBLAS並みの性能を出すことができないことがわかります。 4の倍数でないと性能低下する原因はよくわかりませんが、vmovupd命令でアラインが32 Byteの倍数になっていないといけないとベストな性能が出ないのかなと思っています。 端数処理コードが実行されるのは全体の \Theta(1/N)程度なので、そこの実装が適当でも性能に大きな影響はないはずだからです。

まとめ

正方行列の乗算をもう少し最適化してみました。 行列のサイズが入力として与えられる場合は、さすがにコンパイラの自動ベクトル化だけでは厳しかったので、少しだけアセンブリ言語に手を出しました。 きりのいいサイズの入力の場合には1150 GFLOPS(理論性能の91.6%)もの性能を出すことができました。

また、性能をOpenBLASの実装と比較してみました。 行列サイズが4の倍数の時や、行列サイズが小さいとき(おおむね600以下)は、多くの場合でOpenBLAS並みの性能を達成しました。 一方、非常にナイーブな実装であるため、行列サイズ(またはその1/2か1/3)が512の倍数に近くなる場合は大きな性能低下がありました。 また、行列サイズが4の倍数でない時は、行列が十分に大きいときにOpenBLASよりも高い性能とはなりませんでした。

*1:__uint128_tは移植性があるのかと言われると微妙ですが……。