LLVMのk乗和の最適化手法について

clang-O1以上の最適化オプションをつけると、以下のコード(単純に考えるとΘ(n)命令必要そう)からループを取り除き、Θ(1)命令のコードに最適化します。

uint64_t sum( uint64_t n ) {
    uint64_t ret = 0;
    for( uint64_t i = 0; i < n; ++i ) {
        ret += i;
    }
    return ret;
}

これはclangの驚異的な最適化として挙げられることが多い例ですが、どのような仕組みで行われているのか深く調べたことがなかったので、調べてみました。

環境

以下では、clang+llvm-12.0.1-x86_64-linux-gnu-ubuntu-16.04を使っていきます。 なお、この最適化はclangが独立したCコンパイラとなったclang 3.0の時点ですでに実装されていたようです(clang+llvm-3.0-x86_64-linux-Ubuntu-11_04/bin/clangで確認)。

実験には、以下のソースコードを使います。

// square_sum.c
#include <stdio.h>

unsigned long long n;
int main() {
        unsigned long long i, sum = 0;
        for( i = 0; i < n; ++i ) {
                sum += i * i;
        }
        printf( "%llu\n", sum );
}

clang+llvm-12.0.1-x86_64-linux-gnu-ubuntu-16.04/bin/clang -O2 -S -masm=intel square_sum.cとしてコンパイルすると、ループ部分は以下の機械語コードになります。

        mov     rdi, qword ptr [rip + n]
        test    rdi, rdi
        je      .LBB0_1
# %bb.2:
        lea     rax, [rdi - 1]
        lea     rcx, [rdi - 2]
        mul     rcx
        mov     r8, rax
        mov     rsi, rdx
        lea     rcx, [rdi - 3]
        mul     rcx
                                        # kill: def $ecx killed $ecx killed $rcx
        imul    ecx, esi
        add     edx, ecx
        shld    rdx, rax, 63
        movabs  rax, 6148914691236517206
        imul    rax, rdx
        add     rax, rdi
        shld    rsi, r8, 63
        lea     rcx, [rsi + 2*rsi]
        lea     rsi, [rax + rcx]
        add     rsi, -1
        jmp     .LBB0_3
.LBB0_1:
        xor     esi, esi
.LBB0_3:
        # printfの第二引数はrsiで渡すことになっているのでrsiが計算結果のはず

ちゃんとループが消えていることがわかります。

どこで最適化が行われているのかを特定

前提知識

clangC言語から機械語を生成しますが、その裏側では複数のソフトウェアが動いています。 具体的には、以下のようにLLVMに含まれるいくつかのソフトウェアが順に使われます。

  1. C言語フロントエンドとしてのclangで、C言語ソースコードLLVM-IRに変換する
    • 主に言語に依存する最適化に関係します
  2. ミドルエンドであるoptで、LLVM-IRを最適化されたLLVM-IRに変換する
    • 主に言語にも機械語にも依存しない最適化に関係します
  3. バックエンドであるllcで、LLVM-IRを機械語に変換する
    • 主に機械語に依存する最適化に関係します

バックエンドではない

3.の直前のLLVM-IRを手に入れるためには、clang -S -emit-llvmとすればよいです。 clang+llvm-12.0.1-x86_64-linux-gnu-ubuntu-16.04/bin/clang -O2 -S -emit-llvm square_sum.cとすると、main関数は以下のLLVM-IRになっていました。

; Function Attrs: nofree nounwind uwtable
define dso_local i32 @main() local_unnamed_addr #0 {
  %1 = load i64, i64* @n, align 8, !tbaa !2
  %2 = icmp eq i64 %1, 0
  br i1 %2, label %21, label %3

3:                                                ; preds = %0
  %4 = add i64 %1, -1
  %5 = zext i64 %4 to i65
  %6 = add i64 %1, -2
  %7 = zext i64 %6 to i65
  %8 = mul i65 %5, %7
  %9 = add i64 %1, -3
  %10 = zext i64 %9 to i65
  %11 = mul i65 %8, %10
  %12 = lshr i65 %11, 1
  %13 = trunc i65 %12 to i64
  %14 = mul i64 %13, 6148914691236517206
  %15 = add i64 %1, %14
  %16 = lshr i65 %8, 1
  %17 = trunc i65 %16 to i64
  %18 = mul i64 %17, 3
  %19 = add i64 %15, %18
  %20 = add i64 %19, -1
  br label %21

21:                                               ; preds = %3, %0
  %22 = phi i64 [ 0, %0 ], [ %20, %3 ]
  %23 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([6 x i8], [6 x i8]* @.str, i64 0, i64 0), i64 %22)
  ret i32 0
}

この時点でループが消失しており、x86特有の最適化というわけではなさそうです。

ミドルエンドで行われている

clang+llvm-12.0.1-x86_64-linux-gnu-ubuntu-16.04/bin/clang -O0 -S -emit-llvm square_sum.cとして、フロントエンドでは最適化しないようにします。

; Function Attrs: noinline nounwind optnone uwtable
define dso_local i32 @main() #0 {
  %1 = alloca i32, align 4
  %2 = alloca i64, align 8
  %3 = alloca i64, align 8
  store i32 0, i32* %1, align 4
  store i64 0, i64* %3, align 8
  store i64 0, i64* %2, align 8
  br label %4

4:                                                ; preds = %14, %0
  %5 = load i64, i64* %2, align 8
  %6 = load i64, i64* @n, align 8
  %7 = icmp ult i64 %5, %6
  br i1 %7, label %8, label %17

8:                                                ; preds = %4
  %9 = load i64, i64* %2, align 8
  %10 = load i64, i64* %2, align 8
  %11 = mul i64 %9, %10
  %12 = load i64, i64* %3, align 8
  %13 = add i64 %12, %11
  store i64 %13, i64* %3, align 8
  br label %14

14:                                               ; preds = %8
  %15 = load i64, i64* %2, align 8
  %16 = add i64 %15, 1
  store i64 %16, i64* %2, align 8
  br label %4, !llvm.loop !2

17:                                               ; preds = %4
  %18 = load i64, i64* %3, align 8
  %19 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([6 x i8], [6 x i8]* @.str, i64 0, i64 0), i64 %18)
  %20 = load i32, i32* %1, align 4
  ret i32 %20
}

この時点ではまだループが残っています。

ここで、このLLVM-IRファイル(square_sum.ll)からoptnoneを消します。 optnoneが最適化を強制的に妨げるためです。 二つありますが、一つ目はコメントの中なので、消しても消さなくても同じです。 二つ目を消すのが大事です。

その後、clang+llvm-12.0.1-x86_64-linux-gnu-ubuntu-16.04/bin/opt -O2 -S square_sum.llとしてoptコマンドを適用します。 すると、以下のようにループが消えます。

; Function Attrs: nofree noinline nounwind uwtable
define dso_local i32 @main() local_unnamed_addr #0 {
  %1 = load i64, ptr @n, align 8
  %.not = icmp eq i64 %1, 0
  br i1 %.not, label %._crit_edge, label %.lr.ph.preheader

.lr.ph.preheader:                                 ; preds = %0
  %2 = add i64 %1, -1
  %3 = zext i64 %2 to i65
  %4 = add i64 %1, -2
  %5 = zext i64 %4 to i65
  %6 = mul i65 %3, %5
  %7 = add i64 %1, -3
  %8 = zext i64 %7 to i65
  %9 = mul i65 %6, %8
  %10 = lshr i65 %9, 1
  %11 = trunc i65 %10 to i64
  %12 = mul i64 %11, 6148914691236517206
  %13 = add i64 %1, %12
  %14 = lshr i65 %6, 1
  %15 = trunc i65 %14 to i64
  %16 = mul i64 %15, 3
  %17 = add i64 %13, %16
  %18 = add i64 %17, -1
  br label %._crit_edge

._crit_edge:                                      ; preds = %.lr.ph.preheader, %0
  %.0.lcssa = phi i64 [ 0, %0 ], [ %18, %.lr.ph.preheader ]
  %19 = tail call i32 (ptr, ...) @printf(ptr noundef nonnull dereferenceable(1) @.str, i64 noundef %.0.lcssa)
  ret i32 0
}

したがって、この最適化はミドルエンドで行われているようです。

Induction Variable Simplificationで行われている

さて、-O2はいくつかの最適化パスの集合体です。 具体的にどのような最適化が行われているかを調べるには、--print-after-allをつけて最適化パスの名前とその時点でのLLVM-IRを出力すればよさそうです(llvm opt optionsでググって出てきたoptの概要 - nothingcosmos wikiに書いてありました)。 これにより、*** IR Dump After Induction Variable Simplification ***の後に出力されるLLVM-IRではループが消失していることがわかります。

Induction Variable Simplificationだけではダメ

clang+llvm-12.0.1-x86_64-linux-gnu-ubuntu-16.04/bin/opt --help | grep 'Induction Variable Simplification'とすると、Induction Variable Simplificationは--indvarsというオプションを指定すれば実行されるということがわかります。 clang+llvm-12.0.1-x86_64-linux-gnu-ubuntu-16.04/bin/opt --indvars -S square_sum.llとすればよいようです。 ※llvm-16.0.0ではオプションの文法が違い、The `opt -passname` syntax for the new pass manager is not supported,と怒られます。--passes=indvarsなどとすればよいようです。

しかし、これは残念ながらループのままです。

Induction Variable Simplificationの前提条件

最適化パスは他の最適化パスによりLLVM-IRが整えられていることを前提としていることがあります。 どの最適化パスがInduction Variable Simplificationの前提条件となっているかを調べていきます。

少なくとも、-O2と同じ順番でInduction Variable Simplificationまで最適化パスを実行すれば同じ結果が得られるはずです。 実行されているパスの一覧は、以下のようにして得られます。

clang+llvm-12.0.1-x86_64-linux-gnu-ubuntu-16.04/bin/opt -S -O2 --print-after-all square_sum.ll 2>&1 | grep '\*\*\*'`

以下のような出力を得ます。

*** IR Dump After Module Verifier ***
*** IR Dump After Instrument function entry/exit with calls to e.g. mcount() (pre inlining) ***
*** IR Dump After Simplify the CFG ***
*** IR Dump After SROA ***
(以下略。計121行)

これに対するオプションを調べていきます。 対応するオプションを以下のようにして得ます。

for i in $(seq 121); do clang+llvm-12.0.1-x86_64-linux-gnu-ubuntu-16.04/bin/opt --help 2>&1 | grep -- "- $(clang+llvm-12.0.1-x86_64-linux-gnu-ubuntu-16.04/bin/opt square_sum.ll -S -O2 --print-after-all 2>&1 | grep '\*\*\*' | sed 's/.*After \(.*\) \*\*\*/\1/' | head -n $i | tail -n 1)\$" | awk '{print$1}'; done

llvm-16.0.0では末尾に親切な情報がついてくるため、このコマンドではうまくいきません。

--verify
--ee-instrument
--simplifycfg
--early-cse
--lower-expect
--annotation2metadata
--forceattrs
--inferattrs
--ipsccp
--called-value-propagation
--globalopt
--mem2reg
--deadargelim
--instcombine
--simplifycfg
--prune-eh
--inline
--openmpopt
--function-attrs
--early-cse-memssa
--jump-threading
--correlated-propagation
--simplifycfg
--instcombine
--libcalls-shrinkwrap
--tailcallelim
--simplifycfg
--reassociate
--loop-simplify
--lcssa-verification
--lcssa
--loop-rotate
--licm
--loop-unswitch
--simplifycfg
--instcombine
--loop-simplify
--lcssa-verification
--lcssa
--loop-idiom
--indvars
(以下略)

後はこれを適当に消していきます。すると、以下のように二つの最適化パスだけでうまくいくことがわかります。

clang+llvm-12.0.1-x86_64-linux-gnu-ubuntu-16.04/bin/opt --licm --indvars -S square_sum.ll

--licmは、Loop Invariant Code Motionです。 Loop Invariantはループ不変条件を意味しますが、そんなに大層なことはやっておらず、単にループ中で変わらない変数の計算をループの外に追い出す最適化がかかるだけのようです。 元のコードはループの終了条件にグローバル変数が使われており、毎回ロードが必要だったのが、Induction Variable Simplificationで最適化がかからなかった原因のようです。

実際、以下のコードに変えてみると、clang+llvm-12.0.1-x86_64-linux-gnu-ubuntu-16.04/bin/opt --sroa --indvars -S square_sum_with_end.llでうまくいくことがわかります(optnoneを消すのをお忘れなく)。 ここで、--sroaはローカル変数endをスタックではなくレジスタに割り付けるために指定しています。

// square_sum_with_end.c
#include <stdio.h>

unsigned long long n;
int main() {
        unsigned long long i, sum = 0, end = n;
        for( i = 0; i < end; ++i ) {
                sum += i * i;
        }
        printf( "%llu\n", sum );
}

これらのことから、ループの終了条件がループ中で変わらないとすぐにわかる(=レジスタに割り付けられた)変数であることが前提条件となっていることが推測できます。

Induction Variable Simplificationの仕組み

ソースコード上の場所を特定

ここからはLLVMソースコードを見ていきます。 かなり昔から行われている最適化なので、かなり安定していると考えられます。 今のバージョンのソースコードを見てもそう違いはないでしょう。

GitHub - llvm/llvm-projectでindvarsと検索すると、llvm-project/llvm/lib/Transforms/Scalar/IndVarSimplify.cppが最適化パスのソースコードのようです。 2000行以上あって読むのが大変ですが、bool IndVarSimplify::run(Loop *L)が本体なので、コメントをたよりにそこを見ていきます。 そうすると、rewriteLoopExitValuesというのがいかにも怪しそうだとわかります。

rewriteLoopExitValuesググるLoopUtils.cppの中にある関数のようです。 1315行目からを眺めていくと、1425行目ExitValue = AddRec->evaluateAtIteration(ExitCount, *SE);でループが終わるときの値を計算しているとみて良さそうです。

evaluateAtIterationググるllvm::SCEVAddRecExprというクラスメンバ関数のようです。 ドキュメントを見てみると、isAffineisQuadratic*1などの誘導変数最適化関連っぽいメンバ関数があり、最適化がここで行われている期待が持てます。 肝心のevaluateAtIteration関数は978行目にあって、BinomialCoefficient関数を呼び出しています。 BinomialCoefficient(It, K, SE, ResultTy)関数は、863行目のコメントにあるように、 \frac{It(It-1)\cdots(It-K+1)}{K!}を計算するコードを生成します。 これは、コンビネーションの記法を用いれば、 {}_{It}C_{K}です。

LLVM-IRを確認

clang+llvm-12.0.1-x86_64-linux-gnu-ubuntu-16.04/bin/opt --sroa --indvars -S square_sum_with_end.llをやってみると、以下のようになります。

define dso_local i32 @main() #0 {
  %1 = load i64, i64* @n, align 8
  %2 = zext i64 %1 to i65
  %3 = add i64 %1, -1
  %4 = zext i64 %3 to i65
  %5 = mul i65 %2, %4
  %6 = add i64 %1, -2
  %7 = zext i64 %6 to i65
  %8 = mul i65 %5, %7
  %9 = lshr i65 %8, 1
  %10 = trunc i65 %9 to i64
  %11 = mul i65 %2, %4
  %12 = lshr i65 %11, 1
  %13 = trunc i65 %12 to i64
  br label %14

14:                                               ; preds = %16, %0
  br i1 false, label %15, label %17

15:                                               ; preds = %14
  br label %16

16:                                               ; preds = %15
  br label %14, !llvm.loop !2

17:                                               ; preds = %14
  %18 = mul i64 %10, 6148914691236517206
  %19 = add i64 %18, %13
  %20 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([6 x i8], [6 x i8]* @.str, i64 0, i64 0), i64 %19)
  ret i32 0
}

計算過程を平易に記述してみると、

  1. nをロードする(%1)
  2. n - 1を計算する(%3)
  3. n * (n - 1)を65bit整数として計算する(%5)
  4. n - 2を計算する(%7)
  5. n * (n - 1) * (n - 2)を65bit整数として計算する(%8)
  6. n * (n - 1) * (n - 2) / 2を計算する。64bit整数に収まるはず(%10)
  7. n * (n - 1)を65bit整数として計算する(%11。%5と同じなのにもう一度計算しているのは、共通部分式最適化を施す前の自動生成されたコードそのままであるため)
  8. n * (n - 1) / 2を計算する。64bit整数に収まるはず(%13)
  9. n * (n - 1) * (n - 2) / 3を定数乗算テクニックで計算する(%18)
    • 3の倍数を2/3倍するには、0x5555555555555556を掛けます。この定数は \frac{2^{64}}3+\frac23で、被乗数が3の倍数であれば最初の項は \mod 2^{64}で消えるので、2/3倍できます。
  10. n * (n - 1) * (n - 2) / 3 + n * (n - 1) / 2を計算する。これが求めるものである(%19)

となります。 つまり、 2\,{}_{n}C_{3} + {}_{n}C_{2}を計算するという非常に単純な仕組みです。 この係数21がどこから出てきたのかを、次に説明します。

ループ誘導変数最適化の仕組み

ループ誘導変数最適化の基礎

ループ誘導変数最適化は、典型的な演算強度低減最適化です。 例えば、以下のコードを考えます。

uint32_t sum( uint32_t* arr, size_t n ) {
    uint32_t ret = 0;
    for( size_t i = 0; i < n; ++i ) {
        ret += arr[i];
    }
    return ret;
}

これを素直にコンパイルすると、以下のようになるはずです。

uint32_t sum( uint32_t* arr, size_t n ) {
    uint32_t ret = 0;
    for( size_t i = 0; i < n; ++i ) {
        char* p = (char*)arr + i * 4;
        uint32_t val = *(uint32_t*)p;
        ret += val;
    }
    return ret;
}

乗算が必要になっている点に注意します。 x86ではlea命令があるのであまり気になりませんが、乗算は一般に高コストであるので、可能であれば避けたいです。 そこで、コンパイラは以下のように最適化します。

uint32_t sum( uint32_t* arr, size_t n ) {
    uint32_t ret = 0;
    char* p = (char*)arr;
    for( size_t i = 0; i < n; ++i, p += 4 ) {
        uint32_t val = *(uint32_t*)p;
        ret += val;
    }
    return ret;
}

乗算がなくなり、+= 4という加算になりました。 このような最適化を(ループ誘導変数に関連する)演算強度低減最適化と言います。

発展的なループ誘導変数最適化

この考え方を発展させます。 以下のコードについて考えます。

uint64_t square_sum( uint64_t n ) {
    uint64_t ret = 0;
    for( uint64_t i = 0; i < n; ++i ) {
        ret += i * i;
    }
    return ret;
}

これを誘導変数最適化すると以下のようになります。

uint64_t sum( uint64_t n ) {
    uint64_t ret = 0;
    for( uint64_t i = 0, square = 0; i < n; square += 2 * i + 1, ++i ) {
        ret += square;
    }
    return ret;
}

本物の乗算から定数乗算へと、演算強度が低減されたことがわかります。

ここに現れる定数21こそが、 2\,{}_{n}C_{3} + {}_{n}C_{2}の係数だったのです。 この係数はどう求めるかというと、ループ誘導変数同士の乗算の公式を使っています(ScalarEvolution.cppの3311行目から3314行目)。 Clang の k 乗和の最適化を眺める - えびちゃんの日記では競プロerらしく(?)第二種 Stirling 数との関連を指摘&高速に計算する方法を考察していますが、そんなに高尚なものではないです。 愚直な計算なのでオーダーはおそらくΘ(k3)ですが、そんなに大きな次数の多項式が現れることはありえなさそうです。 そもそも、そういう場合は係数がオーバーフローしてあきらめるので、遅さが問題となることはありません。 逆に、この公式を使うと有理数演算が出てこないのがうれしいのだと思います。

ループ誘導変数同士の乗算の公式の説明の前に、まずループ誘導変数を定式化します。

ループ誘導変数の定式化

三乗和のコードを手で最適化して考えてみる

実際のコードで出てくることはほぼありませんが、二乗だとまだ一般化が見えづらいため、以下のように三乗が出てくるループを考えます。

uint64_t cube_sum( uint64_t n ) {
    uint64_t ret = 0;
    for( uint64_t i = 0, square = 0; i < n; ++i ) {
        ret += i * i * i;
    }
    return ret;
}

 (i+1)^3 = i^3 + 3i^2 + 3i +1であることから、これを以下のように変形できます。

uint64_t cube_sum( uint64_t n ) {
    uint64_t ret = 0;
    for( uint64_t i = 0, cube = 0; i < n; cube += 3 * i * i + 3 * i + 1, ++i ) {
        ret += cube;
    }
    return ret;
}

さて、ここでi * iが出てきました。これに対しても再帰的に適用することで、以下のようになります。

uint64_t cub_sum( uint64_t n ) {
    uint64_t ret = 0;
    for( uint64_t i = 0, square = 0, cube = 0; i < n; cube += 3 * square + 3 * i + 1, square += 2 * i + 1, ++i ) {
        ret += cube;
    }
    return ret;
}

3 * square3 * iに対して演算強度低減最適化をしましょう。

uint64_t cube_sum(uint64_t n) {
    uint64_t ret = 0;
    for( uint64_t i = 0, three_i = 0, three_square = 0, cube = 0; i < n; cube += three_square + three_i + 1, three_square += 6 * i + 3, three_i += 3, ++i ) {
        ret += cube;
    }
    return ret;
}

three_square + three_iはまとめられそうです。

uint64_t cube_sum( uint64_t n ) {
    uint64_t ret = 0;
    for( uint64_t i = 0, tmp = 0, cube = 0; i < n; cube += tmp + 1, tmp += 6 * i + 6, ++i ) {
        ret += cube;
    }
    return ret;
}

6 * iも演算強度低減最適化しましょう。

uint64_t cube_sum( uint64_t n ) {
    uint64_t ret = 0;
    for( uint64_t i = 0, six_i = 0, tmp = 0, cube = 0; i < n; cube += tmp + 1, tmp += six_i + 6, six_i += 6, ++i ) {
        ret += cube;
    }
    return ret;
}

ここで現れる係数(1, 6, 6)が求めるものです。 つまり、retの最終値 {}_nC_2 + 6\,{}_nC_3 + 6\,{}_nC_4だということです。 実際、optの出力はこれを計算するものになっていました。 この式を展開すると  \left(-\frac12n+\frac12n^2\right)+\left(2n-3n^2+n^3\right)+\left(-\frac32n+\frac{11}4n^2-\frac32n^3+\frac14n^4\right)になっていて、  \sum_{i=0}^{n-1} i^3= \frac14n^2 - \frac12n^3 + \frac14n^4と一致しています。

ここで出てきた誘導変数は、cubetmpsix_iでした。 これらのいずれも、cube += tmp + 1tmp += six_i + 6six_i += 6のように、更新の式が(他の誘導変数または0)+(定数)を加算する形になっています。 そこで、Aが誘導変数であるということを、次のように定義しましょう。

  • Aが誘導変数であるとは、他の誘導変数Bを用いて、更新の式がA += B + cのようになることである
    • ここで、増分cはループ不変式である必要がある
    • B0でもよい

※これは正確には加法的な誘導変数に関する定義です。指数関数を乗算に落とす演算強度最適化を行うと、更新の式がA *= B * cB1でもよい)のような乗法的な誘導変数が出現します。LLVMはこれを取り扱えず、 \sum_{i=0}^{n-1}2^iを計算するコードはSIMDを駆使して頑張る命令列になります。

このように定義すると、Aを記述するのに必要な情報は、Aの初期値と、各誘導変数の更新式に現れる増分cだけであり、簡潔に記述することができます。 LLVMの流儀(SCEVAddRecExprの出力フォーマット)では、iのことを{0,+,1}(初期値が0で1ずつ増える)、six_iのことを{0,+,6}(初期値が0で6ずつ増える)、tmpのことを{0,+,6,+,6}(初期値が0で、6+「初期値が0で6ずつ増える値」ずつ増える)、cubeのことを{0,+,1,+,6,+,6}(初期値が0で、1+『初期値が0で6+「初期値が0で6ずつ増える値」ずつ増える値』ずつ増える)、retのことを{0,+,0,+,1,+,6,+,6}(初期値が0で、0+【初期値が0で1+『初期値が0で6+「初期値が0で6ずつ増える値」ずつ増える値』ずつ増える値】ずつ増える)と表記しているようです。 なお、途中に出てくる誘導変数の初期値はすべて0に正規化出来ます。一つ外側の誘導変数の増分cを変えればよいからです。

実際の計算手順(ループ誘導変数同士の乗算の公式を使う方法)

上記やり方は、「three_square + three_iはまとめられそう」のようにあまり系統立った方法ではありませんでした。 実はもう一通りやり方があって、それはまずi * iだけを誘導変数とみなす方法です。 この方法で最適化すると、まず一回目では以下のようになります。

uint64_t cube_sum( uint64_t n ) {
    uint64_t ret = 0;
    for( uint64_t i = 0, square = 0, cube = 0; i < n; square += 2 * i + 1, ++i ) {
        ret += square * i;
    }
    return ret;
}

ここで、squareLLVMの記法でいうと{0,+,1,+,2}で、i{0,+,1}です。 問題は、{0,+,1,+,2}{0,+,1}の乗算をどう取り扱うかです。 ようするに{a,+,b,+,c} a\,{}_iC_0 + b\,{}_iC_1 + c\,{}_iC_2なので、この形同士の乗算を考えればいいわけですが、これには公式があるようです。 ソースコードの方が信頼できるのでScalarEvolution.cppの3340行目から3364行目を眺めると、以下のような計算をしているようです。

// X = { 0, 1, 2 };
// Y = { 0, 1 };
ret = {};
for( x = 0; x < X.size() + Y.size() - 1; ++x ) {
  sum = 0;
  for( y = x; y < 2*x + 1; ++y ) {
    Coeff1 = Choose(x, 2*x - y);
    for( z = max(y-x, y-X.size()+1), z < min(x+1, Y.size()); ++z )
      Coeff2 = Choose(2*x - y, x - z);
      sum += Coeff1*Coeff2*X[y-z]*Y[z];
    }
  }
  ret.push_back(sum);
}

ということはScalarEvolution.cppの3311行目から3314行目choose(x, 2x)*choose(2x-y, x-z)と書いてあるところは間違い(choose(x-1, 2x-y-1)*choose(2x-y-1, x-z)が正しい?)のようです。

これを動かすと、以下のようになります。

x = 0
  y = 0
    z = 0
      Choose(0,0)Choose(0,0)X[0]Y[0]
x = 1
  y = 1
    z = 0
      Choose(1,1)Choose(1,1)X[1]Y[0]
    z = 1
      Choose(1,1)Choose(1,0)X[0]Y[1]
  y = 2
    z = 1
      Choose(1,0)Choose(0,0)X[1]Y[1]
x = 2
  y = 2
    z = 0
      Choose(2,2)Choose(2,2)X[2]Y[0]
    z = 1
      Choose(2,2)Choose(2,1)X[1]Y[1]
  y = 3
    z = 1
      Choose(2,1)Choose(1,1)X[2]Y[1]
  y = 4
    満たすzなし
x = 3
  y = 3
    z = 1
      Choose(3,3)Choose(3,2)X[2]Y[1]
  y = 4
    満たすzなし
  y = 5
    満たすzなし

これを計算すると、一つ目の係数は0、二つ目の係数はChoose(1,0)Choose(0,0)X[1]Y[1]=1、三つ目の係数はChoose(2,2)Choose(2,1)X[1]Y[1]+Choose(2,1)Choose(1,1)X[2]Y[1]=6、四つ目の係数はChoose(3,3)Choose(3,2)X[2]Y[1]=6、となり、無事に{0,+,1,+,6,+,6}が導出できました。

retcubeの総和なので、もう一レベル上がって{0,+,0,+,1,+,6,+,6}になります。これは  0\times{}_iC_0 + 0\times{}_iC_1 + 1\times{}_iC_2+ 6\times{}_iC_3+ 6\times{}_iC_4です。 ループ脱出時にはi == nになっているので、  0\times{}_nC_0 + 0\times{}_nC_1 + 1\times{}_nC_2+ 6\times{}_nC_3+ 6\times{}_nC_4になっているはずです。 これで、結論に達することができました。

clangの出力が汚い理由

上では、clang+llvm-12.0.1-x86_64-linux-gnu-ubuntu-16.04/bin/opt --sroa --indvars -S square_sum_with_end.llという、本質部分だけを抜き出した最適化により得られたきれいなコードを確認しました。 しかし普通にclang -O2としたときに出力される機械語コードのアルゴリズムは、これより汚いです。 同様にオプションを一つずつ削ることで、アルゴリズムが汚くなる原因が--loop-rotateにあることがわかります。 この最適化パスは、while文をdo文に変えるような最適化を行うようです。 つまり、

// whileっぽい方式
    unsigned long long sum = 0, i = 0, end = n;
loop:
    if( i >= end ) { goto loop_end; }
    sum += i;
    ++i;
    goto loop;
loop_end:
    printf( "%llu\n", sum );

とするとループ一周当たりに分岐が二回(最後以外成立しないif文と、loopに戻るgoto文)ですが、

// doっぽい方式
    unsigned long long sum = 0, i = 0, end = n;
    if( end == 0 ) { goto loop_end; }
loop:
    sum += i * i;
    ++i;
    if( i < end ) { goto loop; }
loop_end:
    printf( "%llu\n", sum );

のようにすれば、ループ脱出の条件分岐とloopに戻る分岐を兼ねられて、ループ一周当たりに分岐が一回になります。 この最適化をLoop Rotation最適化というようです。

if( i >= end ) goto loop_end; sum += i * i; ++i;                               goto loop;

                              sum += i * i; ++i; if( i >= end ) goto loop_end; goto loop;

に変わる(ループ本体が文単位で見て左ローテートになっている)あたりが、Rotationと呼ばれる理由なのでしょう。

この最適化は分岐命令の実行数を減らす点では有用ですが、どうもInduction Variable Simplificationとの相性が悪いようです。 clang+llvm-12.0.1-x86_64-linux-gnu-ubuntu-16.04/bin/opt --loop-rotate --sroa --indvars -S square_sum_with_end.llとすると、条件がfalseの分岐が取り残されます。 これを見るに、必ず一回は実行されるループ構造になっているのが原因なのか、最後の一周が取り残されているような雰囲気があります。 実際、optの出力が計算しているのは  {}_{n-1}C_{1} + 3\,{}_{n-1}C_{2} + 2\,{}_{n-1}C_{3}で、一周分少なくなっていることがわかります。 おそらく、{0,+,0,+,1,+,2}(ループの最終周が始まるときのretの値)と{0,+,1,+,2}(ループの最終周におけるi * iの値)を足して{0,+,1,+,3,+,2}になったのでしょう(誘導変数の加算は要素ごとの加算でできます)。

まとめ

  • clangは誘導変数最適化の枠組みで、 \sum_{i=0}^{n-1} i = \frac{n(n-1)}2のような最適化を行っている
    • 具体的には、optの中のInduction Variable Simplification最適化パスで計算されている
  • clangは、SCEVAddRecExprという形で加法的な誘導変数を高度に解析している
    • 加法的な誘導変数であれば、それが k次の誘導変数であってもループの各回や終了時点の値を直接算出できる
      • cube += tmp + 1tmp += six_i + 6six_i += 6のように、誘導変数の連鎖で表す
      •  {}_{i}C_{\kappa} (0\le\kappa\le k)を基底とした表現で持っているのと実質的に同じ
    • 誘導変数同士の演算結果がどのような誘導変数になるかが計算できる
      • 加算は自明(係数の要素ごと加算でよい)
      • 乗算は有理数演算が不要な公式がある
  • 普通に-O2などの最適化オプションを使う場合、Rotate Loop最適化パスが先に実行されるためにきれいなコードとならない
    •  {}_{n-1}C_{\kappa} (0\le\kappa\le k)を基底とした表現で計算されてしまう

関連情報

*1:i*i<5みたいな条件式の時、二次方程式を解けばiの範囲を特定することができます。isQuadratic関数は、このために使われるようです。

Yosysを使ってみる(その1)

Yosysというオープンソースの論理合成ツール(レジスタ転送レベルからゲートレベルに変換する一種のコンパイラ)があるので使ってみました。

github.com

インストール

公式のREADMEとだいたい同じです。Makefileを書き換えてインストール場所を変更すれば、インストールに特権は必要ありません。

sudo apt update
sudo apt-get install build-essential clang bison flex libreadline-dev gawk tcl-dev libffi-dev git graphviz xdot pkg-config python3 libboost-system-dev libboost-python-dev libboost-filesystem-dev zlib1g-dev
mkdir git
cd git
git clone https://github.com/YosysHQ/yosys
cd yosys/
vi Makefile
make -j
make install

ここで、Makefileは、PREFIX ?= /home/lpha/yosysのように書き換えました。 make -jは、12th Gen Intel(R) Core(TM) i9-12900Kで2分くらいかかりました。 99%になったくらいが折り返し地点です。

使ってみる

チュートリアル

Yosys Open SYnthesis Suite :: Screenshots を見ながら試してみます。 まず、counter.vcmos_cells.libをダウンロードして実験用ディレクトリに置いておきます。 その後、以下のように入力します(各行について、先頭から空白文字まではプロンプトなので入力しません)。

$ ~/yosys/bin/yosys
yosys> read_verilog counter.v
yosys> hierarchy -check
yosys> proc; opt; fsm; opt; memory; opt
yosys> techmap; opt
yosys> dfflibmap -liberty cmos_cells.lib
yosys> abc -liberty cmos_cells_plus.lib

4bitの符号なし整数に1を足す回路は以下のように合成できるようです。

ABC RESULTS:              NAND cells:        7
ABC RESULTS:               NOT cells:        4
ABC RESULTS:               NOR cells:        9

ライブラリを書き換えてみる

cmos_cells.libをコピーしたcmos_cells_2.libに、以下の記述を書き加えてみます。

  cell(XOR) {
    area: 4;
    pin(A) { direction: input; }
    pin(B) { direction: input; }
    pin(Y) { direction: output;
             function: "(A*B+A'*B')"; }
  }

さきほどと同じ手順を踏むと、以下の結果を得ます。

ABC RESULTS:               NOT cells:        3
ABC RESULTS:              NAND cells:        3
ABC RESULTS:               XOR cells:        3
ABC RESULTS:               NOR cells:        6

NANDが4つ、NORが3つ、NOTが1つ、それぞれ減った代わりにXORが3つ増えたようです。 得したのかよくわかりません。

XORのコストを上げたら使われなくなるのかな、と思ってarea: 40;などとしてみたら、以下のようになりました。

ABC RESULTS:               NOT cells:        4
ABC RESULTS:              NAND cells:        6
ABC RESULTS:               XOR cells:        1
ABC RESULTS:               NOR cells:        7

XORの使用がひかえめになりました。 でも、XORはNAND四つで作れるので、NANDの10倍のコストに設定したのに使われるのは変です。 面積を最小化しているわけではないのでしょうか。

ゲートカウント

yosys gate countとかでググると、以下のページが出てきます。

github.com

以下のように入力すれば、CMOSゲートに合成してくれるようです。

$ ~/yosys/bin/yosys
yosys> read_verilog counter.v
yosys> hierarchy -check
yosys> proc; opt; fsm; opt; memory; opt
yosys> techmap; opt
yosys> abc -g cmos

出力は以下のようになりました。

ABC RESULTS:               NOT cells:        2
ABC RESULTS:              NAND cells:        1
ABC RESULTS:              AOI3 cells:        1
ABC RESULTS:               NOR cells:        2
ABC RESULTS:               XOR cells:        2

AOI3のような、CMOSらしさあふれるものが使われていることがわかります(AOI3は複合ゲートの一種で、~((a&b)|c)が6トランジスタで作れるというやつです)。

トランジスタカウント

yosys transistor countとかでググると、以下のページが出てきます。

stackoverflow.com

以下のように入力すれば、トランジスタ数を算出してくれるようです。

$ ~/yosys/bin/yosys
yosys> read_verilog counter.v
yosys> hierarchy -check
yosys> proc; opt; fsm; opt; memory; opt
yosys> techmap; opt
yosys> abc -g cmos
yosys> stat -tech cmos

すると、以下の出力を得ます。

12. Printing statistics.

=== counter ===

   Number of wires:                 19
   Number of wire bits:             31
   Number of public wires:           4
   Number of public wire bits:       7
   Number of memories:               0
   Number of memory bits:            0
   Number of processes:              0
   Number of cells:                 12
     $_AOI3_                         1
     $_NAND_                         1
     $_NOR_                          2
     $_NOT_                          2
     $_SDFFE_PP0P_                   4
     $_XOR_                          2

   Estimated number of transistors:         46+

+がついてしまっているのは、フリップフロップトランジスタ数がわからないことによるもののようです(stat.ccの215~218行目261行目)。 トランジスタ数は、cost.hの53行目~68行目に定義されているものが使われていそうです。 XORゲートは10トランジスタで作れるはず(XORゲート - Wikipedia)ですが、それ以外はあっていそうです。

なお、ソースコードのどこで定義されているかは、GitHub上でcmosと検索することで発見しました。

面積

dffmap libraryとかでググると、以下のページが出てきます。

tom01h.exblog.jp

どうやらOklahoma State Universityが180nm用のスタンダードセルライブラリを公開しているようです。 osu018_stdcells.lib(243KB)をダウンロードして、実験用ディレクトリに入れます。 これを、cmos_cells.libの代わりに使います。

$ ~/yosys/bin/yosys
yosys> read_verilog counter.v
yosys> hierarchy -check
yosys> proc; opt; fsm; opt; memory; opt
yosys> techmap; opt
yosys> dfflibmap -liberty osu018_stdcells.lib
yosys> abc -liberty osu018_stdcells.lib
yosys> stat -liberty osu018_stdcells.lib

出力は以下のようになりました。

12. Printing statistics.

=== counter ===

   Number of wires:                 34
   Number of wire bits:             46
   Number of public wires:           4
   Number of public wire bits:       7
   Number of memories:               0
   Number of memory bits:            0
   Number of processes:              0
   Number of cells:                 17
     AND2X1                          1
     AOI21X1                         3
     DFFPOSX1                        4
     INVX1                           2
     NAND3X1                         1
     NOR2X1                          3
     OAI21X1                         2
     XNOR2X1                         1

   Chip area for module '\counter': 754.000000

スタンダードセルの面積の確認

トランジスタを並べて回路を実現するとき、トランジスタを自由に配置できるとすると検証が大変です。 そこで、よくありそうな回路について検証されたトランジスタ配置を作っておき、それを並べて回路を実現することを考えます。 そのような検証されたトランジスタ配置がスタンダードセルです。 スタンダードセルは並べて使うことが前提なので、回路の外枠が長方形になっていて、かつ高さが固定値です。 高さが固定なので、トランジスタ数の多い回路は幅が広くなります。

osu018のスタンダードセルライブラリは、pMOSトランジスタとnMOSトランジスタのペア当たり、高さ8、幅1、となっていると思われます。

  • NOTゲート:2トランジスタ必要で1列、隣とのスペースに1列、で2列なので面積は16
  • NANDゲートやNORゲート:4トランジスタ必要で2列、隣とのスペースに1列、で3列なので面積は24
  • ANDゲートやORゲート:6トランジスタ必要で3列、隣とのスペースに1列、で4列なので面積は32
  • XORゲート:12トランジスタ必要で6列、隣とのスペースに1列、で7列なので面積は56(やっぱり12トランジスタで作るのが主流なのでしょうか)
  • 半加算器:XORとANDなので18トランジスタ必要で9列、隣とのスペースに1列、で10列なので面積は80
  • 全加算器:負論理で24トランジスタ必要※で12列、反転に4トランジスタ必要で2列、隣とのスペースに1列で15列なので面積は120

※参考文献:https://www.hindawi.com/journals/vlsi/2012/173079/XORゲート - Wikipediaの参考文献として載っていました)

三入力NANDゲートは同様の推論で行くと32になりそうなところ36になっていて謎ですが、他は大体あっていそうです。

次に読む

さきほどのOklahoma State Universityのスタンダードセルライブラリを紹介していたブログの他の記事で、遅延の評価をやっているので、今度はこれをやってみたいと思います。

tom01h.exblog.jp

AVX-512の機能を使ったlogf(x)の実装(その2)

前回(AVX-512の機能を使ったlogf(x)の実装(その1) - よーる)の記事で、AVX-512でベクトル実行することを前提としたlogfの高速実装を作りました。 速度を重視しつつもできる限り精度に気を付けた結果、ほとんどの入力に対して誤差を1.5ULP未満とできました。 ただし、対数関数の定義域に属する2139095039個の単精度浮動小数点数のうち、1502個についてだけは誤差が1.5ULPを超えてしまいました(それでも最大誤差は2.2ULP未満です)。 今回は誤差の原因を詳細に解析し、その誤差の要因を避けることのできる係数を発見することで、最大誤差1.5ULPを保証するlogfの高速実装を作ります。

前回の実装の確認

前回の実装の本体を示します。 補助関数まで示すと冗長になるので省略しました(vgetexpps浮動小数点数を仮想的に*1正規化して指数部を取り出す関数、vgetmantps浮動小数点数を仮想的に正規化して仮数部を取り出す関数、vpermpsは16エントリの表引きを行う関数です)。

float my_logf( float x ) {
    float expo = vgetexpps( x );
    float mant = vgetmantps( x );
    float idxf = mant + 0x1.p+19;
    if( mant >= 0x1.78p+0f ) {
        expo += 1.0f;
        mant *= 0.5f;
    }

    float invs = vpermps( idxf, invs_table );
    float t    = std::fma( mant, invs, -1.0f );
    float logs = vpermps( idxf, logs_table );

    float poly = std::fma( std::fma( std::fma( -0x1.fe9d24p-3f, t, 0x1.557ee2p-2f ), t, -0x1.00000cp-1f ), t, 1.0f );
    float ret  = std::fma( poly, t, std::fma( expo, 0x1.62e430p-1f, logs ) );
    return ret;
}

ようするに、xを仮想的に正規化したときの指数部expo e)とxを仮想的に正規化したときの仮数部から決まる16通りの定数invs \sigma_k = s_k^{-1})を使って、 \log x = e\log2 + \log s_k + \log\frac x{2^es_k} のように計算しているということです。

誤差の要因の確認

x1.0fから十分離れている時

 \left|\log\frac x {2^es}\right| |\log x| より十分小さく、 \log\frac x  {2^es} の計算で生じる誤差が最終結果に与える影響は十分小さいです。 よって素直に計算すれば大きな誤差が発生することはありません(実際、1.5ULPを超える誤差が発生する1502ケースはいずれも  \exp(-2^{-6}) \lt x \lt \exp(2^{-5}) です)。

x1.0fに近いとき

expo0.0fとなります。よって、std::fma( expo, 0x1.62e430p-1f, logs )logsになります。 この時、最終結果の誤差は主に以下の四つの誤差の和からなります。

  1. 最後のstd::fma丸め誤差
  2. logs \log sの差
  3. t丸め誤差
  4. poly \log(1+t)/tの差のt

このうち、一つ目は明らかに避けることができません(0.5ULP)。 二つ目は、前回説明したように、 \sigma_k を適切に選ぶことで十分小さくできます。 三つめは、invs1.0fでない時は普通にやると避けることができません(0.5ULP)。

さて、残る「poly \log(1+t)/t の差のt倍」はどうでしょうか。

poly \log(1+t)/t の差」に最も効いてくるのはpolyの計算の最後の丸め誤差(0.5ULP)です。 この0.5ULPの誤差が厄介で、polyが1より小さければ  2^{-25} なのですが、polyが1より大きいと  2^{-24} になってしまいます。  2^{-25} の誤差のt倍は、t仮数部が2に近いときtに対して0.5ULPであり、 2^{-24} の誤差のt倍は、t仮数部が2に近いときtに対して1.0ULPです。 実際にはこれに加えて計算途中の丸め誤差の影響と近似誤差が追加されます。

 k = 0の時

 t = x - 1 になり、tを求める部分では丸め誤差が発生しません。 polyが1より大きくて丸め誤差が大きくなってしまうのは  x \lt 1 の時で、t仮数部が2に近ければlogf(x)の指数部とtの指数部は同じになります。 その時、最終結果に与える影響は、一つ目の誤差が0.5ULP、二つ目の誤差が0ULP、三つ目の誤差が0ULP、四つ目の誤差が1.0ULP以上、で合わせて1.5ULP以上になります。 すべての最悪ケースを同時に引くことはなさそうですが、それでも1.5ULPを保証するのはかなり厳しそうに見えます。

 k \ne 0の時

logf(x)の指数部に比べてtの指数部が十分小さければ、「x1.0fから十分離れている時」と同様の論理で、素直に計算すれば大きな誤差が発生することがありません。

tの指数部がlogf(x)の指数部と同じになる場合を考えます。 これは例えば、 x = 0.984 付近の時に  \sigma_{15} = 32/31 t = x\sigma_{15} -1 = 1.00748\times2^{-6} \log x = -1.03228\times2^{-6} になるような場合のことです。 すると、最終結果に与える影響は、一つ目の誤差が0.5ULP、二つ目の誤差がほぼ0ULP、三つ目の誤差が0.5ULP、四つ目の誤差が0.5ULP以上、で合わせて1.5ULP以上になります。 tの指数部が大きいということはつまり多項式近似の端の方なので、近似誤差を含めて1.5ULP未満に抑えることは難しそうです。 さらに言えば、 x = 1.0315 付近の時などはもっとひどくて、一つ目の誤差が0.5ULP、二つ目の誤差がほぼ0ULP、三つ目の誤差が0.5ULP、四つ目の誤差が1.0ULP以上、で合わせて2.0ULP以上になって、どうにもなりません(実際、2.0ULPを超える誤差が発生する5ケースは全てこれです)。

このような問題が起こるのは、 \sigma_k が切り替わる境界がlogfの指数部やtの指数部が切り替わる境界とずれているからです。 それをまとめたのが以下の表です。

本来あるべき境界 実際の境界 問題点
 \sigma_{15} \sigma_0 の境界  (1+2^{-6})s_{15}=0.983887  \frac{63}{64}=0.984375  k = 15 の負担が大
 \sigma_0 \sigma_1 の境界  \exp(2^{-5})=1.031743  \frac{33}{32}=1.03125  k = 1 の負担が大

このうち、 s_{15}がかかわるほうの問題は s_{15}を適切にとることで解消できます。

戦略

以下の二つの戦略により、上記問題を解決します。

  • idxfを求める部分で浮動小数点数加算ではなく浮動小数点数積和演算を用いることで、命令数を増やさずに境界をずらす
  • 多項式近似の係数を工夫して、 \log(1+t)/t \gt 1 となる範囲( t \lt 0)では極力近似誤差が0に近くなるようにする

さらに、上記戦略をうまく働かせるため、 k ごとに  t としてあり得る範囲を調整します。 これは  \sigma_k \left(1+\frac{k}{16}\right)^{-1} から大胆にずらすことで実現します。

終結果に与える最大誤差が以下のようになることを目指して最適化します。

誤差要因  k = 0 の場合  k = 15 または  k = 1 の場合
最後のstd::fma丸め誤差 0.5ULP 0.5ULP
logs \log sの差 0ULP ほぼ0ULP
t丸め誤差 0ULP 0.5ULP
poly \log(1+t)/tの差のt 1.0ULP 0.5ULP

idxfを求める部分の変更

 \sigma_0 \sigma_1 の境界が  \exp(2^{-5})=1.031743 より大きくなり、かつ  \sigma_{15} \sigma_0 の境界が  \exp(-2^{-6})=0.984496 以下となれば良いです。

実際は近似多項式の係数と同時に最適化するのですが、結論から言うと以下のように設定すればうまくいきました。

    float idxf = fma( mant, 0x1.fd9c88p-1f, 0x1.p+19f );

このようにすることで、 k = 1 かつ  t \lt 0 の場合で誤差が大きくなる問題を防ぎます。 代わりに  k = 0 の場合に取り扱わなければいけないtの範囲が  -0.016130 \lt t \le 0.03125 から  -0.011012 \lt t \lt 0.036084 になります。 負の側は範囲が狭くなり、正の側は範囲が広くなっています。 負の側は精度の要求が厳しく、正の側は精度の要求が緩いので、それと合致しています。 単に  k = 1 かつ  t \lt 0 の場合の回避だけでない効果をここに求めているようです。

 \sigma_k の選択

以下を満たしていれば意外と自由に取れます。

  •  \sigma_{15}
    • tの上限が  2^{-7} 以下になるようにする
      • 境界変更により  -2^{-6} \lt \log(x) \lt -2^{-7} となる付近を取り扱わないといけないので、  2^{-7} \lt t だとpoly \log(1+t)/t の差のt倍が最終結果に対して0.5ULP以上になってしまう
      • 多項式近似の区間の幅を制約するので、tの上限はなるべく  2^{-7} に近いほうが良い
  •  \sigma_1
    • tの上限が  k = 0 の場合のtの上限以下になるようにする
    • tの下限が  k = 15 の場合のtの下限以下になるようにする

多項式近似の工夫

上記  \sigma_k の選択の下、以下を守れば全体が1.5ULP保証になります。 実際には、ほんの少し守れていなくても、誤差要因全てがそろうことはないので、結果的に1.5ULP保証できる実装が手に入ります。

  • 多項式近似の誤差が  t=-2^{-N}(N=7,8,9,\dots) でほぼ0になるようにする
    • poly \log(1+t)/t の差のt倍」はt仮数部が2に近いときtのULPに対して大きくなるので、 t が二の冪乗の時が重要
    •  k = 0 において、t仮数部が2に近くて t \lt 0ならばlogf(x)のULPはtのULPと同じ
    • poly丸め誤差だけで最終結果に与える影響が1.0ULPあるので、多項式近似誤差が少しでもあると1.0ULPを超えてしまう
  • poly \log(1+t)/t の差のt倍について、
    • 区間の左端  t = L において  2^{-29} 未満になるようにする
      •  k = 0 では使われず、 k = 15 k = 1 で使われる
      • その時  2^{-5} \lt |\log x| \lt 2^{-4}なので、logf(x)に対しての0.5ULPは 2^{-29}
    • 区間の右端  t = R において  2^{-28} 未満になるようにする
      •  k = 0 で使われ、その時  2^{-5} \lt \log x \lt 2^{-4}なので、logf(x)に対しての1.0ULPは 2^{-28}
      •  k = 1 で使われ、その時  2^{-4} \lt \log x \lt 2^{-3}なので、logf(x)に対しての0.5ULPは 2^{-28}
        • ただし、 k = 1 の場合にはtの指数部がlogf(x)の指数部よりも一つ小さくてt丸め誤差が最終結果に与える影響が0.25ULPしかないので実際には問題とならない
    •  t \in [\exp(2^{-N-1})-1, \exp(2^{-N})-1)(N=5,6,7,\dots) において  2^{-N-24} 未満になるようにする
      •  k = 0 において重要で、その区間において  2^{-N-1} \le \log x \lt 2^{-N}なので、logf(x)に対しての1.0ULPは 2^{-N-24}

これを満たす近似多項式は、以下のようになります(限界までは最適化していません)。

    const float C4 = -0x1.fb4ef8p-3f;
    const float C3 =  0x1.556fc4p-2f;
    const float C2 = -0x1.ffffe2p-2f;

    float poly = std::fma( std::fma( std::fma( C4, t, C3 ), t, C2 ), t, 1.0f );

この時、tの値ごとの「t丸め誤差」と「poly \log(1+t)/t の差のt倍」が最終結果に与えるULPでみた影響の最大値は、図1のようになります。

図1: tの値ごとの最終丸め誤差以外の誤差が最終結果に与えるULPでみた影響の最大値のグラフ

なんとか±1.0ULPに収まっています。

最終的な実装のソースコード

// -------------------------------------------------------
//  Copyright 2023 lpha-z
//  https://lpha-z.hatenablog.com/entry/2023/09/03/231500
//  Distributed under the MIT license
//  https://opensource.org/license/mit/
// -------------------------------------------------------

#include <bit>
#include <cmath>
#include <cstdint>

float vgetexpps( float x ) {
    if( std::abs(x) < 0x1.p-126f ) {
        return vgetexpps( x * 0x1.p+126f ) - 126.f;
    }
    return static_cast<float>(std::bit_cast<std::uint32_t>(x) << 1 >> 24) - 127.f;
}

float vgetmantps( float x ) {
    if( std::abs(x) < 0x1.p-126f ) {
        x *= 0x1.p+126f;
    }
    return std::bit_cast<float>(std::bit_cast<std::uint32_t>(x) & 0x00ffffff | std::bit_cast<std::uint32_t>(1.0f));
}

float vpermps( float x, float (&tbl)[16] ) {
    return tbl[std::bit_cast<std::uint32_t>(x) & 0xf];
}

float invs_table[16] = {
0x1.0000000p+0f,
0x1.e286920p-1f,
0x1.c726fe0p-1f,
0x1.af35980p-1f,
0x1.99a95e0p-1f,
0x1.861a9e0p-1f,
0x1.746c640p-1f,
0x1.6435820p-1f,
0x1.5564f40p+0f,
0x1.47a8960p+0f,
0x1.3b1c5e0p+0f,
0x1.2f640a0p+0f,
0x1.24958c0p+0f,
0x1.1a813e0p+0f,
0x1.11180c0p+0f,
0x1.04d9b40p+0f,
};
float logs_table[16] = {
+0x0.000000p+0f,
+0x1.e5b538p-5f,
+0x1.e2118ap-4f,
+0x1.5fb476p-3f,
+0x1.c8b0a8p-3f,
+0x1.166fecp-2f,
+0x1.45eeaap-2f,
+0x1.7383aap-2f,
-0x1.26c4fcp-2f,
-0x1.f96f70p-3f,
-0x1.a97736p-3f,
-0x1.5bd74ap-3f,
-0x1.118fbcp-3f,
-0x1.9387e8p-4f,
-0x1.08c23ep-4f,
-0x1.338588p-6f,
};

float my_logf( float x ) {
    float expo = vgetexpps( x );
    float mant = vgetmantps( x );
    float idxf = fma( mant, 0x1.fd9c88p-1f, 0x1.p+19f );
    if( mant >= 0x1.79c328p+0f ) {
        expo += 1.0f;
        mant *= 0.5f;
    }

    float invs = vpermps( idxf, invs_table );
    float t    = std::fma( mant, invs, -1.0f );
    float logs = vpermps( idxf, logs_table );

    const float C4 = -0x1.fb1370p-3f;
    const float C3 =  0x1.556f14p-2f;
    const float C2 = -0x1.ffffe2p-2f;
    float poly = std::fma( std::fma( std::fma( C4, t, C3 ), t, C2 ), t, 1.0f );
    float ret  = std::fma( poly, t, std::fma( expo, 0x1.62e430p-1f, logs ) );

    return ret;
}

これらの改良により上記実装は先週示した実装と比べて命令数を増やさずに最大誤差を小さくすることができました。 対数関数の定義域に属する2139095039個(0x7f7fffff個)の単精度浮動小数点数について誤差を評価してみると、

となりました。  2 \le k \le 14 に対する  \sigma_k は適当にとったので、tの絶対値が小さくなるよう範囲の中心にとれば平均誤差はもう少し減らせるかもしれません。 ただ、0.5ULPを超える誤差が頻繁に発生する原因はstd::fma( expo, 0x1.62e430p-1f, logs )の丸めなので、係数を工夫しても0.5ULPを超える誤差が発生する頻度を下げることは難しそうです。 std::fma( expo, 0x1.62e430p-1f, std::fma( poly, t, logs ) )の順番で計算すれば0.5ULPを超える誤差が発生する頻度を3.3%まで下げることができますが、レイテンシが一命令分伸びます。

また、最大の誤差は-1.45944ULP/+1.47702ULPでした。 ULPで見た最大誤差を1.5ULPより有意に小さくすることは、命令数を増やさずには不可能だと思います。 アルゴリズムを大幅に見直せば回避できる可能性はありますが、おそらく命令数が増えます。

まとめ

logf(x)関数に含まれる誤差の原因を詳細に調べました。 idxfの計算に積和演算を使うことでテーブルの担当領域をずらし、さらに注意深く近似多項式を設計することにより、最大誤差1.5ULP未満を保証する実装を作ることに成功しました。

おまけ:没になった戦略

 \sigma_1 として下位ビットに0が並ぶような浮動小数点数を選択すれば、tを計算するときの丸め誤差を0にすることができます。 しかし、 \sigma_1 として選べる浮動小数点数の選択肢が少なく、 \log s_1浮動小数点数に丸めるときの誤差を小さくすることができません。  \log s_1 の指数部はtの指数部より一つ大きいので、tを計算するときの丸め誤差0.5ULPよりも最終結果に与える影響を小さくしたければ、 \log s_1丸め誤差は0.25ULP以内に収めなければなりません。 これを満たすことができる \sigma_10x1.efp-1くらいしかなく、理想的な値0x1.e2p-1から離れすぎています。 よって、この方法は採用しませんでした。

*1:通常の浮動小数点数には指数部の制限があり非正規化数がありますが、その制限を無視して取り扱うということを意味します。

AVX-512の機能を使ったlogf(x)の実装(その1)

2023年6月に光成滋生さんがAVX-512特有の機能(通常の浮動小数点数演算以外の命令)を使用したlogf(x)のベクトル実装を公開しました(解説記事→AVX-512によるvpermpsを用いたlog(x)の実装)。 具体的には、指数部を取り出すvgetexpps命令、仮数部を取り出すvgetmantps命令、高速なテーブル参照を実現するvpermps命令の三つを使用しています。

これらの機能は非常に強力であり、正しく組み合わせることでlog関数を大幅に高速化することができます。 しかし、このfmathの実装は、AVX-512特有の機能を使わない部分に改善の余地がありそうです。 いろいろな部分を改善していたら、オリジナルとかなり異なるlogf(x)の実装が得られました。

fmathの実装のC++

fmathの実装はアセンブリコードを出力するPythonコードの形で公開されており、読み解くことが困難です(そもそもビルドする方法がわかりませんでした……)。 とりあえず読み解いた感じでは、以下のようなC++コードとほぼ等価な計算を行っているようです。

#include <bit>
#include <cmath>
#include <cstdint>

float vgetexpps( float x ) {
    if( std::abs(x) < 0x1.p-126f ) {
        return vgetexpps( x * 0x1.p+126f ) - 126.f;
    }
    return static_cast<float>(std::bit_cast<std::uint32_t>(x) << 1 >> 24) - 127.f;
}

float vgetmantps( float x ) {
    if( std::abs(x) < 0x1.p-126f ) {
        x *= 0x1.p+126f;
    }
    return std::bit_cast<float>(std::bit_cast<std::uint32_t>(x) & 0x00ffffff | std::bit_cast<std::uint32_t>(1.0f));
}

float vpermps( float x, float (&tbl)[16] ) {
    return tbl[std::bit_cast<std::uint32_t>(x) & 0xf];
}

float vpsrad( float x, std::size_t shamt ) {
    return std::bit_cast<float>(std::bit_cast<std::int32_t>(x) >> shamt);
}

// 以下の部分は、MITSUNARI Shigeoさんがmodified new BSD Licenseで公開
//(https://github.com/herumi/fmath/tree/master)しているfmathライブラリ
// に含まれるfmath_logf_avx512関数の翻案です(不正確かもしれません)。

float tbl1[16] = {
std::bit_cast<float>(0x3f783e10), // 32/33
std::bit_cast<float>(0x3f6a0ea1), // 32/35
std::bit_cast<float>(0x3f5d67c9), // 32/37
std::bit_cast<float>(0x3f520d21), // 32/39
std::bit_cast<float>(0x3f47ce0c), // 32/41
std::bit_cast<float>(0x3f3e82fa), // 32/43
std::bit_cast<float>(0x3f360b61), // 32/45
std::bit_cast<float>(0x3f2e4c41), // 32/47
std::bit_cast<float>(0x3f272f05), // 32/49
std::bit_cast<float>(0x3f20a0a1), // 32/51
std::bit_cast<float>(0x3f1a90e8), // 32/53
std::bit_cast<float>(0x3f14f209), // 32/55
std::bit_cast<float>(0x3f0fb824), // 32/57
std::bit_cast<float>(0x3f0ad8f3), // 32/59
std::bit_cast<float>(0x3f064b8a), // 32/61
std::bit_cast<float>(0x3f020821), // 32/63
};
float tbl2[16] = {
std::bit_cast<float>(0xbcfc14c8), // log(32/33)ではなくlog(tbl1[0])
std::bit_cast<float>(0xbdb78694), // 以下同様
std::bit_cast<float>(0xbe14aa96),
std::bit_cast<float>(0xbe4a92d4),
std::bit_cast<float>(0xbe7dc8c6),
std::bit_cast<float>(0xbe974716),
std::bit_cast<float>(0xbeae8ded),
std::bit_cast<float>(0xbec4d19d),
std::bit_cast<float>(0xbeda27bd),
std::bit_cast<float>(0xbeeea34f),
std::bit_cast<float>(0xbf012a95),
std::bit_cast<float>(0xbf0aa61f),
std::bit_cast<float>(0xbf13caf0),
std::bit_cast<float>(0xbf1c9f07),
std::bit_cast<float>(0xbf2527c4),
std::bit_cast<float>(0xbf2d6a01),
};
float LOG2 = std::bit_cast<float>(0x3f317218);

float fmath_logf( float x ) {
    float expo = vgetexpps( x );
    float mant = vgetmantps( x );
    float idxf = vpsrad( mant, 23 - 4 );

    float b    = vpermps( idxf, tbl1 );
    float c    = std::fma( mant, b, -1.0f );
    float logb = vpermps( idxf, tbl2 );

    float z    = std::fma( expo, LOG2, -logb );

    float xm1  = x - 1.0f;
    float abs  = std::bit_cast<float>(std::bit_cast<std::uint32_t>(xm1) & 0x7fffffff);
    if( abs < 0.02f ) {
        c = xm1;
        z = 0.0f;
    }

    float ret  = std::fma( std::fma( std::fma( std::fma( -0.25008487f, c, 0.3333955701f ), c, -0.49999999f ), c, 1.0f ), c, z );
    return ret;
}

改良

事前計算点の改良

この実装では -\log\left(1+\frac{k+0.5}{16}\right) (0\le k\lt16)を事前計算したテーブルを用意しています。 0.5ずれた点の値を持っているのはidxfを切り捨てで作っているのに合わせているからです。 しかしその代償としてx1.0f付近(logf(x)0.0f付近になって精度の要求が厳しくなる所)の時に精度が足りなくなっていて、それを解決するために分岐が発生しています。 幸い、アルゴリズムを大きく変える必要があるわけではないので、マスク付き演算で実現できますが、追加コストとして五命令かかっています。

この問題は、素直に \log\left(1+\frac{k}{16}\right) (0\le k\lt16)を事前計算したテーブルを持つことで解決可能です。 そのためにはidexfを最近接丸めで作る必要がありますが、これは「すごく大きな数を足すと仮数部が丸められる」というテクニック(これも光成滋生さんの解説で知りました→高速な倍精度指数関数expの実装)を使えばよいです。 具体的には、以下のようにします。

    float mant = vgetmantps( x ); // 1.0f <= mant < 2.0f で出てくる
    float idxf = mant + 0x1.p+19; // 最近接の1/16の倍数に丸める
    if( mant >= 0x1.f8p+0f ) { // 0x1.f8p+0f <= mant < 2.0f の時
            mant *= 0.5f; // mant が b = 1.0 に近くなるように mant をずらす
            expo += 1.0f; // expo もずらす
    }

二命令減らし、精度のための追加コストを三命令にすることができました。 ただし、このif文を外すと完全におかしな値を計算することになるので、「精度を犠牲にしても高速化したいときはこの範囲を消してください」というコードにはなっていません。

テーブル値の精度を上げる

上で行ったexpomantのずらしは、もっと積極的にやってよいです。 境界を0x1.f8p+0fではなく0x1.78p+0fとして、 8\le k \lt 15用の値を \log\left(1+\frac{k}{16}\right)ではなく \log\left(\frac12+\frac{k}{32}\right)にします。 これによりlogbの絶対値が小さくなって精度が上がるので、全体の精度が向上します。

近似多項式の改良

使われている係数が謎です。 見た感じでは |c| \lt 0.019に対する最良近似多項式っぽいですが、範囲をそのようにする理由がわかりません(範囲は |c| \lt \frac1{33} = 0.030303のはずです)。 何か間違いがあるように思えます。

上の最適化を行った場合のcの範囲である -\frac{0.5}{17} \lt c \lt \frac{0.5}{16}を前提にExcelを使って最良近似多項式を求めたところ、四次の係数は-0x1.fe9d24p-3f、三次の係数は0x1.557ee2p-2f、二次の係数は-0x1.00000cp-1f、とすれば良さそうだということがわかりました。 当該範囲での近似誤差は、絶対誤差 6.8\times10^{-10} = 1.46\times2^{-31}未満です。

テーブル値の精度が高くなるように事前計算点を選ぶ

よく考えると、事前計算に用いる点 \sigma_k \left(1+\frac{k}{16}\right)^{-1}浮動小数点数に丸めたものでなくてもよいはずです。 あまりにも離れた値を使うと |c|が大きくなって多項式近似誤差が大きくなってしまいますが、多少のずれは許容されます。 そこで、 \sigma_kとして \left(1+\frac{k}{16}\right)^{-1}にそれなりに近い浮動小数点数のうち \log \sigma_k浮動小数点数に丸めたときに誤差が十分小さくなるものを採用しましょう。

精度的に重要なのは、 \sigma_{15}の修正のようです。 ここで、 \sigma_{15}として適当な値は \frac{32}{31}付近には全くありません。 これは \left.\frac{d}{dx}\log(x)\right|_{x=\frac{32}{31}}=\frac{31}{32}でありこの値が二進浮動小数点数で正確に表せる数であることによるもので、この付近では隣の浮動小数点数におけるlogの返り値も下位ビット(=丸め誤差)がほぼ同じになります。 そのため、 \sigma_{15}0x1.084210p+0から0x1.084550p+0へとかなり大きくずらす必要がありました。

なお、この考え方は元のソースコードの精度を上げるif文がないバージョン(0.5ずれた点でテーブルを作るので命令数が少ない代わりに精度が低い方法)にも適用可能です。 tbl1[0] = 0x1.f081dcp-1ftbl2[0] = -0x1.f76c56p-6f(真の値は-0x1.f76c55ffa3p-6付近なので丸め誤差は0.00071ULP)を使うことで |\log(x)| \lt 2^{-6}での最大絶対誤差を 3.4\times10^{-9} = 1.79 \times2^{-29}未満にすることができます。

改良版のソースコード

#include <bit>
#include <cmath>
#include <cstdint>

float vgetexpps( float x ) {
    if( std::abs(x) < 0x1.p-126f ) {
        return vgetexpps( x * 0x1.p+126f ) - 126.f;
    }
    return static_cast<float>(std::bit_cast<std::uint32_t>(x) << 1 >> 24) - 127.f;
}

float vgetmantps( float x ) {
    if( std::abs(x) < 0x1.p-126f ) {
        x *= 0x1.p+126f;
    }
    return std::bit_cast<float>(std::bit_cast<std::uint32_t>(x) & 0x00ffffff | std::bit_cast<std::uint32_t>(1.0f));
}

float vpermps( float x, float (&tbl)[16] ) {
    return tbl[std::bit_cast<std::uint32_t>(x) & 0xf];
}

float invs_table[16] = {
0x1.000000p+0f, // 16/16
0x1.e1e1e2p-1f, // 16/17
0x1.c71c72p-1f, // 16/18
0x1.af286cp-1f, // 16/19
0x1.99999ap-1f, // 16/20
0x1.861862p-1f, // 16/21
0x1.745d18p-1f, // 16/22
0x1.642c86p-1f, // 16/23
0x1.555556p+0f, // 32/24
0x1.47ae14p+0f, // 32/25
0x1.3b13b2p+0f, // 32/26
0x1.2f684cp+0f, // 32/27
0x1.24924ap+0f, // 32/28
0x1.1a7b96p+0f, // 32/29
0x1.111112p+0f, // 32/30
0x1.084550p+0f, // ~32/31
};
float logs_table[16] = {
+0x0.000000p+0f,
+0x1.f0a30ap-5f,
+0x1.e27074p-4f,
+0x1.5ff306p-3f,
+0x1.c8ff7ap-3f,
+0x1.1675cap-2f,
+0x1.4618bap-2f,
+0x1.739d7ep-2f,
-0x1.269624p-2f,
-0x1.f991c4p-3f,
-0x1.a93ed8p-3f,
-0x1.5bf408p-3f,
-0x1.1178eep-3f,
-0x1.9335e4p-4f,
-0x1.08599ap-4f,
-0x1.047a88p-5f,
};

float my_logf( float x ) {
    float expo = vgetexpps( x );
    float mant = vgetmantps( x );
    float idxf = mant + 0x1.p+19;
    if( mant >= 0x1.78p+0f ) {
        expo += 1.0f;
        mant *= 0.5f;
    }

    float invs = vpermps( idxf, invs_table );
    float t    = std::fma( mant, invs, -1.0f );
    float logs = vpermps( idxf, logs_table );

    float poly = std::fma( std::fma( std::fma( -0x1.fe9d24p-3f, t, 0x1.557ee2p-2f ), t, -0x1.00000cp-1f ), t, 1.0f );
    float ret  = std::fma( poly, t, std::fma( expo, 0x1.62e430p-1f, logs ) );
    return ret;
}

これらの改良により上記実装はfmathの実装よりも二命令減らしつつ最大誤差も小さくすることができました。 対数関数の定義域に属する2139095039個(0x7f7fffff個)の単精度浮動小数点数について誤差を評価してみると、

であり、最大の誤差は-1.925ULP/+2.122ULPでした。

この結果を見ると、ほとんどの場合は真値に最近接の浮動小数点数かその前後の浮動小数点数以外を返しており、それ以外の浮動小数点数を返す(≒1.5ULPを超える誤差が発生する)ことはほとんどないことがわかります。 誤差の原因を詳細に調べて各所の係数を工夫したところ、最大誤差1.5ULP未満を保証する実装を作ることに成功しました。 ただ、その説明は長くなるので、今回の記事はここまでにします。

おまけ:速度重視版のソースコードの改良版

 |\log(x)| \lt 2^{-6}での絶対誤差は 3.4\times10^{-9} = 1.79 \times2^{-29}未満となります。 その外側では-2.289ULP/+1.157ULPです。 まともに最適化していないので、誤差はもう少し小さくできそうです。

#include <bit>
#include <cmath>
#include <cstdint>

float vgetexpps( float x ) {
    if( std::abs(x) < 0x1.p-126f ) {
        return vgetexpps( x * 0x1.p+126f ) - 126.f;
    }
    return static_cast<float>(std::bit_cast<std::uint32_t>(x) << 1 >> 24) - 127.f;
}

float vgetmantps( float x ) {
    if( std::abs(x) < 0x1.p-126f ) {
        x *= 0x1.p+126f;
    }
    return std::bit_cast<float>(std::bit_cast<std::uint32_t>(x) & 0x00ffffff | std::bit_cast<std::uint32_t>(1.0f));
}

float vpermps( float x, float (&tbl)[16] ) {
    return tbl[std::bit_cast<std::uint32_t>(x) & 0xf];
}

float vpsrad( float x, std::size_t shamt ) {
    return std::bit_cast<float>(std::bit_cast<std::int32_t>(x) >> shamt);
}

// 以下の部分は、MITSUNARI Shigeoさんがmodified new BSD Licenseで公開
//(https://github.com/herumi/fmath/tree/master)しているfmathライブラリ
// に含まれるfmath_logf_avx512関数をC++に翻案したものをベースとし、
// テーブル値と多項式の係数を変更することで精度を改善したものです。

float tbl1[16] = {
0x1.f081dcp-1f,
0x1.d42b36p-1f,
0x1.badefep-1f,
0x1.a43a42p-1f,
0x1.8fa128p-1f,
0x1.7c91dcp-1f,
0x1.6bb3b6p-1f,
0x1.5c7e0ap-1f,
0x1.4e9076p-1f,
0x1.40c2dep-1f,
0x1.34bd30p-1f,
0x1.29e50ap-1f,
0x1.1f3f34p-1f,
0x1.16067ap-1f,
0x1.0c996cp-1f,
0x1.048122p-1f,
};
float tbl2[16] = {
-0x1.f76c56p-6f,
-0x1.6e9312p-4f,
-0x1.290ddap-3f,
-0x1.9489aep-3f,
-0x1.fb779ap-3f,
-0x1.2fc65cp-2f,
-0x1.5e3292p-2f,
-0x1.89f0fep-2f,
-0x1.b3b51ap-2f,
-0x1.ded9ccp-2f,
-0x1.02fbeep-1f,
-0x1.154a94p-1f,
-0x1.27ed54p-1f,
-0x1.38a234p-1f,
-0x1.4a4b10p-1f,
-0x1.59f5fap-1f,
};
float LOG2 = std::bit_cast<float>(0x3f317218);

float fast_logf( float x ) {
    float expo = vgetexpps( x );
    float mant = vgetmantps( x );
    float idxf = vpsrad( mant, 23 - 4 );

    float b    = vpermps( idxf, tbl1 );
    float c    = std::fma( mant, b, -1.0f );
    float logb = vpermps( idxf, tbl2 );

    float z    = std::fma( expo, LOG2, -logb );

    float ret  = std::fma( std::fma( std::fma( std::fma( -0x1.0049acp-2f, c, 0x1.557f3ep-2f ), c, -0x1.fffff8p-2f ), c, 1.0f ), c, z );
    return ret;
}

中点の正しい計算方法

二つの浮動小数点数 a, bについて、中点の数学的な値は \frac{a+b}2です。 以下の記事に触発されたので、この中点を浮動小数点数に正しく丸めた値を得る方法について考えます。

中点はどうやって計算するべきか - kashiの日記

当該記事によれば、この問題は意外と難しく、分岐なしに浮動小数点数演算だけで正しく求める方法はまだ知られていないようです。

何が難しいのか

足したらオーバーフローする

まず、もっとも単純に思いつくのは(a+b)*0.5と計算する方法でしょう。 しかしこれは、a+bの時点でオーバーフローする可能性があるので、中点を正しく求めることができません。

引いてもオーバーフローする

高精度に計算する方法として、a + (b-a)*0.5という方法が紹介されることもあります。 しかし、今度はb-aがオーバーフローする可能性があります。 それは、baの符号が異なり、それぞれ絶対値がすごく大きい場合です。 まぁそんな二数の中点を正確に求めたいということはなさそうですが、正しいアルゴリズムではないという点が気になります。

アンダーフローした際の丸め誤差

では、a*0.5 + b*0.5a + (b*0.5 - a*0.5)とするのはどうでしょうか。 これらはオーバーフローの可能性はありませんが、今度は逆にアンダーフローの可能性が出てきます。 通常、二進浮動小数点数の0.5倍は丸め誤差なしに正確に表せますが、結果がアンダーフローする場合は丸め誤差が生じる可能性があります。 最終結果に与える影響が同じ方向の丸め誤差が二つ以上の演算で発生した場合、正しくない答えが求まることがあります。

正しく求める方法

以下では、実装を検証しやすいようにfloatを使います(double用のmidpoint関数は入力のパターンが 2^{128}通りあって検証が困難です)。 floatの特徴はほぼ使っていないので、簡単にdoubleに置き換えることができます。

作戦

std::fma( a, 0.5f, b*0.5f )がほとんどの場合に正しい値を返すことに注目します。 これが正しくない丸めを行う必要条件は、b*0.5fの計算で丸め誤差が発生することです。 この丸め誤差に由来する真の中点とstd::fma( a, 0.5f, b*0.5f )の差の二倍hを算出することを考えます。

使用テクニック

double-doubleのように、実数を二つの浮動小数点数 x, yの組で表現することを考えます。 しかし、double-doubleのように x + yを表現するのではなく、 x + 0.5yを表現することを考えます。

この表現を使うと、任意の浮動小数点数 aについて、 0.5aが表現できます。  x=0, y=aとすればよいのでこれ自体は自明です。 重要なのは、以下のTwoProduct類似の手順で正規化することができることです。

std::pair<float, float> Half( float a ) {
    float hi = a * 0.5f;
    float lo = std::fma( hi, -2.0f, a );
    return { hi, lo };
}

この時、lo0.0fFLT_TRUE_MIN-FLT_TRUE_MINになります(丸めモードが何であってもそうなります)。

このように0.5倍を無誤差で扱う方法を導入します。

概略

まず、先のHalf関数を用いて、次の e, fを求めます。

auto [e, f] = Half(b)

つづいて、仮の中点値 m'を求めます。

float m_ = std::fma(a, 0.5f, e)

 m'と真の中点の差の二倍 hを計算します。

float h = f ? std::fma( m_, -2.0f, a ) + b : 0.0f;

最後に、 m'にそれを足し合わせ、最終結果を求めます。

float ret = std::fma( h, 0.5f, m_ );

もちろん、これでは分岐が発生しているのでダメです。 しかし、f0.0fFLT_TRUE_MIN-FLT_TRUE_MINしかありえないことを思い出します。 これを利用すると、以下のようにして分岐を消すことができます。

std::fesetround(FE_TOWARDZERO);
float half_or_zero = std::fma( f, f, -FLT_MIN/FLT_TRUE_MIN ) + FLT_MIN/FLT_TRUE_MIN;
std::fesetround(FE_TONEAREST);

float m_ = std::fma( a, 0.5f, e );
float h = std::fma( m_, -2.0f, a ) + b;

float ret = std::fma( h, half_or_zero, m_ );

これでも微妙にダメな点がいくつかあります。 一つ目は、b±0x1.fffffep-126fの時にうまくいかないことです。 この問題は、eを求めるときに零への丸めを使えば回避できます。 二つ目は、hがオーバーフローして無限大になると0.0fをかけてもNaNになってしまうことです。 この問題も、hを求めるときに零への丸めを使えば回避できます。

以上でおそらくすべての問題が解決されます。

具体的なソースコード

float TrueMidpoint( float a, float b ) {
    int rm = std::fegetround();

    // Half
    std::fesetround(FE_TOWARDZERO);
    float e = b * 0.5f;
    float f = std::fma( e, -2.0f, b );

    float half_or_zero = std::fma( f, f, -0x1.p+23f ) + 0x1.p+23f;

    // FastTwoSum(?)
    std::fesetround(rm);
    float m_ = std::fma( a, 0.5f, e );
    std::fesetround(FE_TOWARDZERO);
    float h = std::fma( m_, -2.0f, a ) + b;

    std::fesetround(rm);
    float ret = std::fma( h, half_or_zero, m_ );

    return ret;
}

8命令で実現することができました。

FastTwoSumをさかさまに使うと0.0fが返ってくることを利用してfを見ずにできないかといろいろ試しましたが、うまくいく方法を見つけることができませんでした。

検証

次の分岐ありコード(中点はどうやって計算するべきか - kashiの日記で紹介されている方法)と結果が一致するかを確認します。

float TrueMidpointWithIf( float a, float b ) {
    if( std::abs(a) > 1.0f && std::abs(b) > 1.0f ) {
        float ha = a * 0.5f; // a*0.5fは丸め誤差なし
        float hb = b * 0.5f; // b*0.5fは丸め誤差なし
        float ret = ha + hb; // std::fma(a, 0.5f, b*0.5f)で良さそうだけれど、文献通りにしておく
        return ret; // よって、誤差の生じる丸めは一回だけで、所望の方向に丸められる
    } else {
        float sum = a + b; // 非正規化数になるときは丸め誤差なし
        float ret = sum * 0.5f; // sumが正規化数ならば丸め誤差なし
        return ret; // よって、誤差の生じる丸めは一回だけで、所望の方向に丸められる
    }
}

実験に使ったソースコードは以下の通りです。

#include <cstdint>
#include <cstring>
#include <iostream>
#include <random>

template<class To, class From>
To bit_cast( From from ) {
    To to;
    std::memcpy( &to, &from, sizeof to );
    return to;
}

float getRandom( std::mt19937_64& mt ) {
        std::uint64_t s = mt();
        switch( s % 4 ) {
        case 0: return bit_cast<float, std::uint32_t>( ( s >> 32 & 0x80ffffff ) ); // tiny float
        case 1: return bit_cast<float, std::uint32_t>( ( s >> 32 & 0x80ffffff ) + 0x7e800000 ); // large float
        case 2: return bit_cast<float, std::uint32_t>( s >> 32 | 0x007ffff0 ); // large mantissa float
        case 3: return bit_cast<float, std::uint32_t>( s >> 32 ); // random float
        default: return 0.0f;
        }
}

int main( int argc, char* argv[] ) {
        static std::mt19937_64 mt(std::atoi(argv[1]));
        for( std::uint64_t i = 0; i < 100'000'000; ++i ) {
                float a = getRandom(mt);
                float b = getRandom(mt);
                if( a != a || b != b ) { continue; }
                if( bit_cast<std::uint32_t>(TrueMidpoint( a, b )) != bit_cast<std::uint32_t>(TrueMidpointWithIf( a, b )) ) {
                    std::cout << std::hexfloat << a << " " << b << " ";
                    std::cout << TrueMidpoint( a, b ) << " ";
                    std::cout << TrueMidpointWithIf( a, b ) << std::endl;
                }
        }
}

乱数のシード値としては1から16を使いました。 いずれの場合も出力がなかったため、TrueMidpointTrueMidpointWithIfと同じ動作をするようです。

ただし、このテストで確かめられていないコーナーケース入力について、以下の問題があることがわかっています。

  • TrueMidpoint( -0.0f, -0.0f )-0.0fを返してほしいが、0.0fを返してしまう
  • TrueMidpointの引数に無限大が含まれると、NaNが返ってきてしまう

これらは、two-component演算を使っている限り避けようがない問題です。

改良1

上で使った分岐を消す方法を応用して、TrueMidpointWithIf*1に含まれる分岐を消すことを考えます。

float TrueMidpoint( float a, float b ) {
    int rm = std::fegetround();
    float f = 0x1.p-149f / std::fma( b, 4.0f, 0x1.p-148f );
    std::fesetround(FE_TOWARDZERO);
    float one_or_half = std::fma( f, f, -0x1.fffffep+22f ) + 0x1.p+23f;
    std::fesetround(rm);
    float half_or_one = 1.5f - one_or_half;
    float sum_or_mid = std::fma( a, one_or_half, b * one_or_half );
    float ret = sum_or_mid * half_or_one;
    return ret;
}

このようにすると、除算という重たい演算が含まれてしまいますが、入力として無限大が来る場合や、入力の双方が-0.0fである場合も正しく取り扱うことができます。 無限大を普通の数に戻す方法は除算しかないので、この除算は不可避です。

改良2

入力として無限大が来ないならば、除算を取り除くことができます。

float TrueMidpoint( float a, float b ) {
    int rm = std::fegetround();
    float f = b * 0x1.p-149;
    std::fesetround(FE_TOWARDZERO);
    float half_or_one = std::fma( f, f, -0x1.fffffep+22f ) + 0x1.p+23f;
    std::fesetround(rm);
    float one_or_half = 1.5f - half_or_one;
    float sum_or_mid = std::fma( a, one_or_half, b * one_or_half );
    float ret = sum_or_mid * half_or_one;
    return ret;
}

7命令まで減りました。

まとめ

正しく中点を求める、分岐のないアルゴリズムを紹介しました。 命令数が多いので、実用的ではありません。

*1:正確にはfmaを使うように変更してbの絶対値の大小のみで分岐するようにしたバージョン

constexpr関数の評価速度

最近のC++では、constexprを使って手軽にコンパイル時計算ができます。 C++20ではstd::vectorstd::stringなどの動的なデータ構造もconstexpr指定されたため、さらに利便性が高まりました。 コンパイル時に多少の探索を行って最適な値を発見して埋め込む、といった利用方法も現実的となりました。

一方で、constexpr関数の評価は実質的にインタプリタ実行であり、速度が気になります。 今回はそれを調査してみました。

測定環境

  • Intel Core i9 12900K (Alder Lake-S)
    • 5.1 GHzくらいで動いている
  • Ubuntu 20.04.3 LTS on WSL2 (Windows 11)
  • g++ (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0(aptでインストールしたもの)
  • clang++ version 12.0.1(公式によるx86_64-unknown-linux-gnu向けプレビルドバイナリ)
  • 時間の測定は何回か実行した平均

実験に使ったソースコード

#include <iostream>

X unsigned long long f( unsigned long long n ) {
        unsigned long long x = 1234567890987654321;
        for( ; n--; ) {
                x ^= x << 7;
                x ^= x >> 9;
        }
        return x;
}

int main() {
        Y unsigned long long x = f( N );
        std::cout << x << std::endl;
}

X, Y, Nはコンパイルオプションで指定します。

g++

g++は、以下の戦略でコンパイル時計算を行うようです。

  • 関数がconstexpr指定されていなくて、変数がconstexpr指定されていなければ、コンパイル時計算を行わない。
  • 関数がconstexpr指定されていなくて、変数がconstexpr指定されていれば、コンパイルエラー(C++の仕様)。
  • 関数がconstexpr指定されていて、変数がconstexpr指定されていなければ、
    • とりあえずコンパイル時計算してみる。
    • 評価回数上限(-fconstexpr-loop-limit-fconstexpr-ops-limitで指定できる)に達したら、コンパイル時計算を打ち切る。
  • 関数がconstexpr指定されていて、変数がconstexpr指定されていなければ、
    • とりあえずコンパイル時計算してみる。
    • 評価回数上限(-fconstexpr-loop-limit-fconstexpr-ops-limitで指定できる)に達したら、コンパイルエラーにする。

実行時間は以下のようになりました。

X Y Opt limit N=1000000 N=10000000
なし なし -O0 default 0.16 s 0.15 s
なし なし -O2 default 0.17 s 0.16 s
constexpr なし -O0 default 0.49 s 0.48 s
constexpr なし -O2 default 1.08 s 1.15 s
constexpr なし -O0 999999999 1.89 s 23.3 s
constexpr なし -O2 999999999 1.97 s 23.6 s
constexpr constexpr -O0 999999999 1.94 s 22.5 s
constexpr constexpr -O2 999999999 1.92 s 23.3 s

普通の命令セットだと一周当たり4cycleくらいで実行可能な命令列にコンパイルされるはずなので、5GHzのCPUで動かすと1秒間に109周以上できるはずです。 107周に23秒かかっているので、コンパイルされたコードの2500倍程度遅いということがわかりました。

clang++

clang++は、変数がconstexpr指定されている時にのみコンパイル時計算を行うようです。

実行時間は以下のようになりました。

X Y Opt limit N=1000000 N=10000000 N=30000000 N=100000000
なし なし -O0 999999999 0.16 s 0.17 s
なし なし -O2 999999999 0.18 s 0.20 s
constexpr なし -O0 999999999 0.20 s 0.18 s
constexpr なし -O2 999999999 0.17 s 0.18 s
constexpr constexpr -O0 999999999 0.83 s 6.8 s
constexpr constexpr -O2 999999999 0.85 s 6.8 s 20 s 67 s

108周に67秒かかっているので、コンパイルされたコードの700倍程度遅いということがわかりました。

まとめ

コンパイル時計算(constexpr関数の評価)は機械語コードの実行と比べて三桁(オーダーで1000倍)程度遅いことがわかりました。

符号なし乗算器の作り方の勉強と11ビット乗算器の設計

 Nビット符号なし乗算器は、 Nビットの符号なし整数二つを受け取り、 2Nビットの符号なし整数を出力する回路です。

 Nビット乗算器は、部分積を作る N^2個のANDゲートと N(N-2)個の全加算器(full adder, FA)を組み合わせて作ることができます。 この構成で全加算器が N(N-2)個必要なのは、以下のような説明ができます。 全加算器は3入力2出力なので信号線を一本減らす効果があります。 部分積は全部で N^2個あり、これを 2N本に減らす必要があります。 よって、 N^2 - 2N = N(N-2)本減らすためには、 N(N-2)個の全加算器が必要です。

以下では、さらに効率的に乗算器を作るいくつかのテクニックを書き、その後11ビット乗算器を設計してみます。

テクニック

Wallace Tree

Wallace Tree(ワラス木、またはより原音に近くウォレス木)は、 N(N-2)個の全加算器をどのように配置するかの方法の一つです。 Wallace Treeでは、全加算器をバランスした木状に配置します。 順次加算していく方式ではレイテンシが \Theta(N)段となってしまいますが、バランスした木とすることでレイテンシを \Theta(\log N)段に抑えます。

実際には、本当に全加算器だけで作ってしまうと桁上げ伝搬加算器(ripple-carry adder)になってしまって、 \Omega(N)段のレイテンシが不可避です。 そこで、各位の信号線の本数が2本になるまではWallace Treeで減らしていって、その先は桁上げ先読み加算器などの高速加算器を使う、といった方法が使われます。 Wallace Treeで各位の信号線が2本になるまで減らすレイテンシが \Theta(\log N)段、高速加算器のレイテンシが \Theta(\log N)段、ということで全体として \Theta(\log N)段のレイテンシが達成できます。

Wallace Treeの模式図として、図1のようなものが紹介されることがあります。 しかし、これは「わかっている」技術者向けに書かれた、信号線の減り具合を示すための模式図です。 桁上げは次の位に送らなければいけませんから、実際の回路は図2のようになります*1

図1: 19入力2出力のWallace Tree(六段)の模式図

図2: 六段のWallace Treeの実際。橙色で示した線は下位からの桁上げ信号

Booth encoding

Booth encoding(ブース符号化)は、部分積を減らす方法の一つです。 Radix-2、Radix-4、Radix-8、……などがあります。

Radix-2 Booth encoding

Radix-2 Booth encodingは、乗数の二進表現に1が連続した部分がある場合、それが(シフトと)足し算一回と引き算一回で実現できるという事実を利用します。 普通の乗算器で31倍を実現するには16倍、8倍、4倍、2倍、1倍、の5つの数を足しこむ必要があります。 しかし、被乗数の32倍を足し、被乗数の1倍を引く、とすれば実現できます。 Radix-2 Booth encodingは、このような最適化を行うことです。 なぜRadix-2 Booth「encoding」というかというと、この手順は乗数を二進表現(各位が0か1)から冗長二進表現(各位が-1か0か1)に変換していることに相当するからです。

31倍のような特殊な数だけではなく、乗数が0b0011111011110のような1が連続している部分がありさえすれば、この手法を適用することで加算の回数を減らすことができます。 残念ながら、Radix-2 Booth encodingを用いても加算の回数を減らせない乗数がある(例えば0b010101010101とか)ので、ハードウェア実装には向きません。 ハードウェア実装では、最悪に備えて回路を用意しないといけないからです。 手回し式計算器を使ったり、乗算器のないCPUで計算したりする際などには役立つようです。 あるいは、可変レイテンシの直列乗算器を作るためにも使えそうです。

Radix-4 Booth encoding

Radix-4 Booth encodingはこの仕組みを昇華させたものであり、非常に効率的であるためハードウェア乗算器の実装でよく用いられるようです。 この方法では、乗数を二桁ごとに区切って、四進表現(各位が0か1か2か3)から冗長四進表現(各位が-2か-1か0か1か2)に変換します。 ここで、各位としてあり得る数に3がなくなっているのが重要です。 3倍は加算器を使わないと表現できませんが、-2倍、-1倍、0倍、1倍、2倍は加算器なしにシフトだけで簡単に作ることができます(※本当は符号反転に加算器が必要ですが解決する方法があります。後述します)。 これにより、部分積が乗数の二桁ごとに一つとなるため、加算すべき部分積を半分に減らすことができます。 代わりに-2倍・-1倍・0倍・1倍・2倍から選択する回路が必要になりますが、部分積の数が半分になることはWallace Tree約二段分(二段で9本を4本にすることができる)に相当するので、メリットが上回ります。

エンコーディングするためには、a[3:1], a[5:3], a[7:5], a[9:7], ...のように、二桁ずつずれていく重なりのある三桁を取り出す必要があります。 具体的には、a×bを計算するとき、

  • a[2k+1:2k-1]が000なら、bの0倍を出力する
  • a[2k+1:2k-1]が001か010なら、bの 1\times2^{2k}倍を出力する
  • a[2k+1:2k-1]が011なら、bの 2\times2^{2k}倍を出力する
  • a[2k+1:2k-1]が100なら、bの -2\times2^{2k}倍を出力する
  • a[2k+1:2k-1]が101か110なら、bの -1\times2^{2k}倍を出力する
  • a[2k+1:2k-1]が111なら、bの0倍を出力する

とします。

これに加えてa[0]由来の部分積を作る必要があります。 このためには、a[-1]=0だと思ってa[1:(-1)]について同様のエンコーディングを行えばよいです。 この項は-2倍か-1倍か0倍か1倍が出てきます。

また、最上位付近をどこまで足さないといけないかも注意が必要です。 と言ってもほぼ明らかで、最上位ビットa[N-1]の影響があるところまで足せばよいです。

 Nが奇数であれば、a[N]=0だと思ってa[N, N-2]の分まで足せばよいです。 この項は、0倍か1倍か2倍が出てきます。 つまり、負の数は出てきません。

 Nが偶数であれば、a[N+1]=a[N]=0だと思ってa[N+1, N-1]の分まで足せばよいです。 この項は、0倍か1倍が出てきます。 つまり普通のANDゲートで作られる部分積です。

Radix-8 Booth encoding

Radix-8 Booth encodingはさらにこの仕組みを発展させたもので、乗数を三桁ごとに区切って八進表現(各位が0, 1, 2, 3, 4, 5, 6, 7のいずれか)から冗長八進表現(各位が-4, -3, -2, -1, 0, 1, 2, 3, 4のいずれか)に変換します。 冗長八進表現だとシフトだけでは作れない3倍が必要ですが、これは一回作ってしまえば使いまわせるので、十分大きなビット幅の乗算器であれば加算すべき部分積を三分の一に減らせる効果が上回ります。 部分積を三分の一に減らせる効果がRadix-4で部分積を二分の一に減らしたのよりも上回る境界は、3倍を作るコストを考慮に入れても一応 N=6程度なのでほとんどの乗算器に適用可能に見えます。 しかし、3倍を計算するレイテンシは隠蔽できず、またRadix-4とRadix-8の差分は部分積の数が三分の二になること、つまりWallace Tree一段分の効果しかないのでレイテンシを短くする効果はありません。 極力トランジスタ数を減らしたいということではない限り、Radix-4で十分に思えます。

符号反転の方法

Booth encodingを行うと、符号なし乗算器を作っているのに負の数が出てきて面倒ですが、うまく取り扱う方法が知られています。 以下Radix-4を前提とした説明を行いますが、負の数を取り扱うアイデア自体はRadixによらず適用可能です。

さて、「-2倍、-1倍、0倍、1倍、2倍は加算器なしにシフトだけで簡単に作ることができる」と書きましたが、実際にはこれは正しくありません。 符号を反転させるためには、一の補数を取った後に1を足さないといけなくて、ここに加算器が必要だからです。 この問題は、最後に足す1を一種の部分積だとみなすことにより解決できます。

また、このままだと符号拡張のためのビットが部分積として残ってしまい、部分積が四分の三に減るにとどまります。 ここで、符号拡張で生じる部分積は、すべて0(例えば、0b000000)かすべて1(例えば、0b111111)である点に注目します。 あらかじめ0b111111という定数を足す回路を用意しておけば、すべて1の時はそのままでよく、すべて0の時はそれに1を足せばよいです( \overline{S}を足せばよいです)。 また、足す定数は N/2個(合計 N^2/4ビットくらい)出現しますが、事前にその和を計算しておけば、部分積として入力しなければいけない定数は一つ( Nビットくらい)にできます。 この定数は、Radix-4の時0b010101...0101011になります。 なお、この符号拡張をキャンセルする手法は、符号付き乗算器を設計するときにも役立ちます。

さて、この定数の最後は0b011となっていますが、この部分は符号ビットの残り( \overline{S_0})と足し合わせると0b \overline{S_0}S_0S_0になります。 このようにすることで、部分積が多い位ができてしまうことを回避できます。

これを説明したのが図3です。 最後のRadix-4 Booth乗算器は、図解としてよく見るものと一致しています。

図3: Booth encodingで現れる負の数への対処方法。一つ目:-bを~b+1と表現し、足すべき1も部分積とみなす。二つ目:符号拡張で出現するビットを定数+符号ビットの反転( \overline{S})のように表現する。三つ目:  \overline{S_0}と定数の加算を実行することで、部分積の数は変わらないが位あたりの最大部分積数を削減できる。

その他の高速乗算アルゴリズムの適用

相当大きな乗算器でない限り、Karatsuba法やToom–Cook乗算、フーリエ変換を使った乗算などは役に立ちません。 計算量がナイーブ法を下回るビット幅が最も小さいKaratsuba法を例にとって、なぜ役に立たないのかを説明します。

Karatsuba法は、乗算回数を減らす代わりに加減算の回数が増える方法です。 ここで、ハードウェア乗算器においては、1ビット×1ビットの乗算はANDゲート1つでできる一方、これで生じた信号線を1本減らすには5~6ゲートからなる全加算器が必要です。 つまり、ANDゲートより5倍以上大きな全加算器の数が重要であり、その観点からすると乗算と加算はほぼ同じコストです。 これがKaratsuba法がビット数が相当大きくない限り役に立たない直観的な理由です。 部分積が減って全加算器が減る量と前処理・後処理で全加算器が増えてしまう量の大小により、Karatsuba法を使用すべきかが決まります。

必要な全加算器の数を概算してみます。 Radix-4 Booth乗算器では、全加算器が \frac1 2 N^2個程度必要です。 ここに(再帰的にではなく一段階だけ*2)Karatsuba法を適用すると、前処理に 2N個程度、後処理に 6N個程度の全加算器が必要で、代わりに乗算部分の全加算器を \frac3 8 N^2個程度に減らせます。 これが釣り合うのは N=64の時です。 したがって、Radix-4 Booth encodingを使用していれば、64ビット乗算器まででKaratsuba法を使うメリットは全くありません。 同様の計算により、128ビット乗算器であってもRadix-16 Booth encodingを採用すれば十分で、Karatsuba法を採用するメリットはなさそうだということがわかります。 それ以上、例えば256ビット乗算器を作る際には、Karatsuba法が役立つかもしれません。

Toom-2.5はさらに加算が多いので、相当大きなビット数でないとメリットがありません。 Radix-16 Booth encodingを前提にすると、1024ビット乗算器あたりでKaratsuba法に追いつくようです。 Toom-3以降は係数に \frac1 6などが出てきてもう一度乗算(循環小数定数での乗算なので実質的には \Theta(\log N)回の Nビット加算)が必要なので、実用するのはほぼ不可能でしょう。 フーリエ変換などなおさらです。

Wallace Treeの実際の設計

基本

Wallace Treeを設計するには、図4のように各位に何本の信号線があるかを書いたExcelシートを使うと便利です。 この時、各位にある信号線を3(全加算器の入力数)で割って切り捨てた数だけ、その位に全加算器を配置します。 これを何段か繰り返して、すべての位で信号線の本数が2以下になったとき、Wallace Treeが終了します(図5)。

図4: Excelを使ったWallace Treeの設計(その1)

図5: Excelシートを使ったWallace Treeの設計(その2)

アンチパターン

ただし、これだけだと各位の信号線の本数が上位ビットから「2, 2, 2, ..., 2, 2, 3」などとなった時、桁上げ伝搬加算器ができてしまいます。 この時にのみ、半加算器(half adder, HA)を使います。 半加算器は、2入力2出力で信号線の本数が減らないので極力使いたくないですが、アンチパターンが発生した場合だけは使うことでWallace Treeの段数を大幅に削減することができます。 その設計を行ったのが図6です。

図6: 半加算器を使ったWallace Treeのアンチパターンの回避

図7のようにもっと前の段から半加算器を使うことで段数を削減できることがあるようですが(Dadda乗算器と関係している?)、よくわかっていません。

図7: 半加算器を技巧的に配置することで段数を減らした例

11ビット符号なし乗算器の設計

なぜ11ビットを選んだかというと、半精度浮動小数点数(IEEE754のbinary16)の仮数部が(暗黙の1を含めて)11ビットだからです。 基本的には、Wallace TreeとRadix-4 Booth encodingを組み合わせて設計します。 定数との加算を実行し、0b \overline{S_0}S_0S_0とするタイプで設計します。

すると、図8のようになります。

図8: 11ビット乗算器の設計

0b \overline{S_0}S_0S_0を使うタイプで設計したため、一つの位あたりの最大の信号線の本数は六本となり、うまく三段に収めることができました。

Radix-4 Booth乗算器を素直に作ると、各位に所属する部分積の個数の分布がきれいな山形にならずギザギザした感じになるので、半加算器を多く使う必要があるようです。 この問題を解決するため、Booth encoderを工夫し、先に一桁だけ桁上げを行ってしまうことを考えます。 すると、図9のようになります。

図9: 少しだけ工夫した11ビット乗算器の設計

各位に所属する部分積の個数の分布がきれいな山形になったためか、半加算器の数を7から4に減らすことができました。

さて、「先に一桁だけ桁上げを行う」という部分で回路が増えていたら半加算器の数が減っても意味がありません。 そこで、その部分の回路を実際に確認してみます。 先に一桁だけ桁上げを行うRadix-4 Booth encoderの-2倍/-1倍/0倍/1倍/2倍、を選択する回路は以下を出力すればよいです。

a[2:0] 出力   ... out[3] out[2] out[1]その2 out[1]その1 out[0]
000 0b ... 0 0 0 0 0
001 1b ... b[3] b[2] b[1] 0 b[0]
010 1b ... b[3] b[2] b[1] 0 b[0]
011 2b ... b[2] b[1] b[0] 0 0
100 -2b ... ~b[2] ~b[1] ~b[0] 1 0
101 -1b ... ~b[3] ~b[2] ~b[1] ~b[0] b[0]
110 -1b ... ~b[3] ~b[2] ~b[1] ~b[0] b[0]
111 0b ... 1 1 1 1 0

この回路は、図10のように実現することができます。

図10: 先に一桁だけ桁上げを行うRadix-4 Booth encoder

これに対して通常のRadix-4 Booth encoderは図11のような回路となります。

図11: 通常のRadix-4 Booth encoder

ほぼ回路は変わっていないため、先に一桁だけ桁上げを行うことで半加算器を減らすメリットを受けられることがわかりました。

Verilog HDL実装

ここまでの実装をVerilog HDLで記述しました。 この実装は、あり得る222通りの入力すべてに対して正しい答えを返すことをVerilatorにより確認しています。

module multiplier11(lhs, rhs, result);
    input  wire[10:0] lhs;
    input  wire[10:0] rhs;
    output wire[21:0] result;

    function[11:0] Radix_4_Booth_select_012;
        input [1:0] a;
        input[10:0] b;

        case( a )
        2'b000: Radix_4_Booth_select_012 = 12'b0;
        2'b001: Radix_4_Booth_select_012 = { 1'b0, b };
        2'b010: Radix_4_Booth_select_012 = { 1'b0, b };
        2'b011: Radix_4_Booth_select_012 = { b, 1'b0 };
        endcase
    endfunction

    function[14:0] Radix_4_Booth;
        input [2:0] a;
        input[10:0] b;
        Radix_4_Booth[14] =  1'b1;
        Radix_4_Booth[13] = ~a[2];
        Radix_4_Booth[12:1] = Radix_4_Booth_select_012( a[1:0] ^ {2{a[2]}}, b );
        if(a[2])
            Radix_4_Booth[12:2] = ~Radix_4_Booth[12:2];
        Radix_4_Booth[0] = a[2] & ~Radix_4_Booth[1];
    endfunction

    function[15:0] Radix_4_Booth_lsb;
        input [1:0] a;
        input[10:0] b;
        Radix_4_Booth_lsb[15] = ~a[1];
        Radix_4_Booth_lsb[14] =  a[1];
        Radix_4_Booth_lsb[13] =  a[1];
        Radix_4_Booth_lsb[12:1] = Radix_4_Booth_select_012( a[1:0] ^ { 1'b0, a[1] }, b );
        if(a[1])
            Radix_4_Booth_lsb[12:2] = ~Radix_4_Booth_lsb[12:2];
        Radix_4_Booth_lsb[0] = a[1] & ~Radix_4_Booth_lsb[1];
    endfunction

    function[11:0] Radix_4_Booth_odd_msb;
        input [1:0] a;
        input[10:0] b;
        Radix_4_Booth_odd_msb = Radix_4_Booth_select_012( a, b );
    endfunction

    wire[15:0] a0b  = Radix_4_Booth_lsb    (lhs[ 1:0], rhs);
    wire[14:0] a2b  = Radix_4_Booth        (lhs[ 3:1], rhs);
    wire[14:0] a4b  = Radix_4_Booth        (lhs[ 5:3], rhs);
    wire[14:0] a6b  = Radix_4_Booth        (lhs[ 7:5], rhs);
    wire[14:0] a8b  = Radix_4_Booth        (lhs[ 9:7], rhs);
    wire[11:0] a10b = Radix_4_Booth_odd_msb(lhs[10:9], rhs);

/*
    assign result =   { 7'b0, a0b[15:1]        } + { 20'b0, a0b[0], 1'b0 }
                    + { 6'b0, a2b[14:1],  2'b0 } + { 18'b0, a2b[0], 3'b0 }
                    + { 4'b0, a4b[14:1],  4'b0 } + { 16'b0, a4b[0], 5'b0 }
                    + { 2'b0, a6b[14:1],  6'b0 } + { 14'b0, a6b[0], 7'b0 }
                    + {       a8b[14:1],  8'b0 } + { 12'b0, a8b[0], 9'b0 }
                    + {      a10b      , 10'b0 } ;
*/

    function[1:0] full_adder;
        input a;
        input b;
        input c;
        full_adder = { a&b|b&c|c&a, a^b^c };
    endfunction

    wire[1:0] x3  = full_adder(a0b[ 4], a2b[ 2], a2b[ 0]);
    wire[1:0] x4  = full_adder(a0b[ 5], a2b[ 3], a4b[ 1]);
    wire[1:0] x5  = full_adder(a0b[ 6], a2b[ 4], a4b[ 2]);
    wire[1:0] x6  = full_adder(a0b[ 7], a2b[ 5], a4b[ 3]);
    wire[1:0] x7  = full_adder(a0b[ 8], a2b[ 6], a4b[ 4]);
    wire[1:0] x8  = full_adder(a0b[ 9], a2b[ 7], a4b[ 5]);
    wire[1:0] x9  = full_adder(a0b[10], a2b[ 8], a4b[ 6]);
    wire[1:0] x10 = full_adder(a0b[11], a2b[ 9], a4b[ 7]);
    wire[1:0] x11 = full_adder(a0b[12], a2b[10], a4b[ 8]);
    wire[1:0] x12 = full_adder(a0b[13], a2b[11], a4b[ 9]);
    wire[1:0] x13 = full_adder(a0b[14], a2b[12], a4b[10]);
    wire[1:0] x14 = full_adder(a0b[15], a2b[13], a4b[11]);
    wire[1:0] y9  = full_adder(a6b[ 4], a8b[ 2], a8b[ 0]);
    wire[1:0] y10 = full_adder(a6b[ 5], a8b[ 3], a10b[0]);
    wire[1:0] y11 = full_adder(a6b[ 6], a8b[ 4], a10b[1]);
    wire[1:0] y12 = full_adder(a6b[ 7], a8b[ 5], a10b[2]);
    wire[1:0] y13 = full_adder(a6b[ 8], a8b[ 6], a10b[3]);
    wire[1:0] y14 = full_adder(a6b[ 9], a8b[ 7], a10b[4]);
    wire[1:0] y15 = full_adder(a6b[10], a8b[ 8], a10b[5]);
    wire[1:0] y16 = full_adder(a6b[11], a8b[ 9], a10b[6]);
    wire[1:0] y17 = full_adder(a6b[12], a8b[10], a10b[7]);
    wire[1:0] y18 = full_adder(a6b[13], a8b[11], a10b[8]);
    wire[1:0] y19 = full_adder(a6b[14], a8b[12], a10b[9]);

    wire[1:0] z5  = full_adder( x5 [0], x4 [1], a4b[0]);
    wire[1:0] z6  = full_adder( x6 [0], x5 [1], a6b[1]);
    wire[1:0] z7  = full_adder( x7 [0], x6 [1], a6b[2]);
    wire[1:0] z8  = full_adder( x8 [0], x7 [1], a6b[3]);
    wire[1:0] z9  = full_adder( x9 [0], x8 [1], y9 [0]);
    wire[1:0] z10 = full_adder( x10[0], x9 [1], y10[0]);
    wire[1:0] z11 = full_adder( x11[0], x10[1], y11[0]);
    wire[1:0] z12 = full_adder( x12[0], x11[1], y12[0]);
    wire[1:0] z13 = full_adder( x13[0], x12[1], y13[0]);
    wire[1:0] z14 = full_adder( x14[0], x13[1], y14[0]);
    wire[1:0] h15 = full_adder(a2b[14], x14[1],  1'b0 ); // Half Adder
    wire[1:0] z15 = full_adder(a4b[12], y14[1], y15[0]);
    wire[1:0] z16 = full_adder(a4b[13], y15[1], y16[0]);
    wire[1:0] z17 = full_adder(a4b[14], y16[1], y17[0]);
    wire[1:0] z20 = full_adder(a8b[13], a10b[10], y19[1]);

    wire[1:0] w7  = full_adder(z7 [0], z6 [1], a6b[0]);
    wire[1:0] w8  = full_adder(z8 [0], z7 [1], a8b[1]);
    wire[1:0] w9  = full_adder(z9 [0], z8 [1],  1'b0 ); // Half Adder
    wire[1:0] w10 = full_adder(z10[0], z9 [1], y9 [1]);
    wire[1:0] w11 = full_adder(z11[0], z10[1], y10[1]);
    wire[1:0] w12 = full_adder(z12[0], z11[1], y11[1]);
    wire[1:0] w13 = full_adder(z13[0], z12[1], y12[1]);
    wire[1:0] w14 = full_adder(z14[0], z13[1], y13[1]);
    wire[1:0] w15 = full_adder(z15[0], z14[1], h15[0]);
    wire[1:0] w16 = full_adder(z16[0], z15[1], h15[1]);
    wire[1:0] w17 = full_adder(z17[0], z16[1],  1'b0 ); // Half Adder
    wire[1:0] w18 = full_adder(y18[0], z17[1], y17[1]);
    wire[1:0] w19 = full_adder(y19[0],  1'b0 , y18[1]); // Half Adder
    wire[1:0] w21 = full_adder(a8b[14], a10b[11], z20[1]);

    assign result =  { w21[0], w19, w17, w15, w13, w11, w9, w7, z5, x3, a0b[3:1] }
                   + { 1'b0, z20[0], w18, w16, w14, w12, w10, w8, 1'b0, z6[0], 1'b0, x4[0], 1'b0, a2b[1], a0b[0], 1'b0 };
endmodule

full_adderの入力としてどれを選ぶかは改善の余地があります。 今回は信号の名前がなるべくレギュラーになるように配置しましたが、実際は桁上げ信号の配線が短くなるように選ぶとよいでしょう。 また、複合ゲートが使える場合はa&b|b&c|c&aより~(a&b|b&c|c&a)の方が少ないトランジスタ数で作れるので、一部の全加算器は負論理としたほうが良いかもしれません。

符号ビットキャンセル定数伝搬最適化

符号ビットキャンセル定数となんらかの部分積を入力として持つ半加算器は定数伝搬によりほぼノーコスト(NOTゲート一つ)となります。 これを三か所に適用することで、三つの半加算器をなくせるようです。

図12: 符号ビットキャンセル定数を伝搬して半加算器を減らした例

小さな高速加算器を使う

単純に考えると最終段の高速加算器の幅は22ビットとなりますが、下位をWallace Tree実行中に桁上げ伝搬加算器で求めれば高速加算器の幅を縮めることができます。 下位のどこまでを確定させられるかを考えてみると、

  • 20の位は部分積が1つなので既に確定している
  • 21の位はWallace Treeの一段目で確定できる
  • 22の位はWallace Treeの二段目で確定できる
  • 23の位はWallace Treeの三段目で確定できる

となります。 また、最上位ビットは高速加算器のキャリー出力(これは高速加算器の結果が出る一段前に出力される)とのXORで作れるので、ここも高速加算器を通す必要がありません。

ここまでの結果から、高速加算器は24の位から220の位までを取り扱えばよいことがわかりました。 高速加算器の幅が17ビットとなり、微妙に二冪を超えているのでいまいちです。 試行錯誤してみたところ、1bit+2bit加算器と半加算器を適切な位置に配置すれば、24の位もWallace Treeの三段目で確定できることがわかりました。 1bit+2bit加算器は、図13に示すような4ゲートからなる3入力3出力の回路*3で、全加算器よりも小さいのでこれを使うことは問題ないはずです。 out[1]出力の遅延が全加算器一段分よりやや長い気もしますが、ここはクリティカルパスではないので問題ありません。

図13: 1bit+2bit加算器

これにより高速加算器の幅を16ビットとした設計が図14です。

図14: 最後の高速加算器を16ビットとしてみた例

まとめ

  • 乗算器を作るために有用な二つのテクニック
    • Wallace Tree
    • (Radix-4) Booth encoding
  • 実際に11ビット符号なし乗算器を設計した
    • 先に一桁だけ桁上げするRadix-4 Booth encoderを使うことで半加算器を減らせる
    • 符号ビットキャンセル定数を伝搬させることで半加算器を減らせる
    • 1bit+2bit加算器を導入することで、高速加算器の幅を16ビットにすることができた

*1:さらに全加算器の入力位置を取り換えて信号線が短くなる工夫をしてもよいでしょう。

*2:再帰の末端付近ではオーダーが悪いものの定数倍の小さな実装に切り替える」ということをやりたくて、その閾値を求めたいので再帰的に適用しないのが正しいです。

*3:というか半加算器二つを直列につないだものそのものです。