Yosysを使ってみる(その1)

Yosysというオープンソースの論理合成ツール(レジスタ転送レベルからゲートレベルに変換する一種のコンパイラ)があるので使ってみました。

github.com

インストール

公式のREADMEとだいたい同じです。Makefileを書き換えてインストール場所を変更すれば、インストールに特権は必要ありません。

sudo apt update
sudo apt-get install build-essential clang bison flex libreadline-dev gawk tcl-dev libffi-dev git graphviz xdot pkg-config python3 libboost-system-dev libboost-python-dev libboost-filesystem-dev zlib1g-dev
mkdir git
cd git
git clone https://github.com/YosysHQ/yosys
cd yosys/
vi Makefile
make -j
make install

ここで、Makefileは、PREFIX ?= /home/lpha/yosysのように書き換えました。 make -jは、12th Gen Intel(R) Core(TM) i9-12900Kで2分くらいかかりました。 99%になったくらいが折り返し地点です。

使ってみる

チュートリアル

Yosys Open SYnthesis Suite :: Screenshots を見ながら試してみます。 まず、counter.vcmos_cells.libをダウンロードして実験用ディレクトリに置いておきます。 その後、以下のように入力します(各行について、先頭から空白文字まではプロンプトなので入力しません)。

$ ~/yosys/bin/yosys
yosys> read_verilog counter.v
yosys> hierarchy -check
yosys> proc; opt; fsm; opt; memory; opt
yosys> techmap; opt
yosys> dfflibmap -liberty cmos_cells.lib
yosys> abc -liberty cmos_cells_plus.lib

4bitの符号なし整数に1を足す回路は以下のように合成できるようです。

ABC RESULTS:              NAND cells:        7
ABC RESULTS:               NOT cells:        4
ABC RESULTS:               NOR cells:        9

ライブラリを書き換えてみる

cmos_cells.libをコピーしたcmos_cells_2.libに、以下の記述を書き加えてみます。

  cell(XOR) {
    area: 4;
    pin(A) { direction: input; }
    pin(B) { direction: input; }
    pin(Y) { direction: output;
             function: "(A*B+A'*B')"; }
  }

さきほどと同じ手順を踏むと、以下の結果を得ます。

ABC RESULTS:               NOT cells:        3
ABC RESULTS:              NAND cells:        3
ABC RESULTS:               XOR cells:        3
ABC RESULTS:               NOR cells:        6

NANDが4つ、NORが3つ、NOTが1つ、それぞれ減った代わりにXORが3つ増えたようです。 得したのかよくわかりません。

XORのコストを上げたら使われなくなるのかな、と思ってarea: 40;などとしてみたら、以下のようになりました。

ABC RESULTS:               NOT cells:        4
ABC RESULTS:              NAND cells:        6
ABC RESULTS:               XOR cells:        1
ABC RESULTS:               NOR cells:        7

XORの使用がひかえめになりました。 でも、XORはNAND四つで作れるので、NANDの10倍のコストに設定したのに使われるのは変です。 面積を最小化しているわけではないのでしょうか。

ゲートカウント

yosys gate countとかでググると、以下のページが出てきます。

github.com

以下のように入力すれば、CMOSゲートに合成してくれるようです。

$ ~/yosys/bin/yosys
yosys> read_verilog counter.v
yosys> hierarchy -check
yosys> proc; opt; fsm; opt; memory; opt
yosys> techmap; opt
yosys> abc -g cmos

出力は以下のようになりました。

ABC RESULTS:               NOT cells:        2
ABC RESULTS:              NAND cells:        1
ABC RESULTS:              AOI3 cells:        1
ABC RESULTS:               NOR cells:        2
ABC RESULTS:               XOR cells:        2

AOI3のような、CMOSらしさあふれるものが使われていることがわかります(AOI3は複合ゲートの一種で、~((a&b)|c)が6トランジスタで作れるというやつです)。

トランジスタカウント

yosys transistor countとかでググると、以下のページが出てきます。

stackoverflow.com

以下のように入力すれば、トランジスタ数を算出してくれるようです。

$ ~/yosys/bin/yosys
yosys> read_verilog counter.v
yosys> hierarchy -check
yosys> proc; opt; fsm; opt; memory; opt
yosys> techmap; opt
yosys> abc -g cmos
yosys> stat -tech cmos

すると、以下の出力を得ます。

12. Printing statistics.

=== counter ===

   Number of wires:                 19
   Number of wire bits:             31
   Number of public wires:           4
   Number of public wire bits:       7
   Number of memories:               0
   Number of memory bits:            0
   Number of processes:              0
   Number of cells:                 12
     $_AOI3_                         1
     $_NAND_                         1
     $_NOR_                          2
     $_NOT_                          2
     $_SDFFE_PP0P_                   4
     $_XOR_                          2

   Estimated number of transistors:         46+

+がついてしまっているのは、フリップフロップトランジスタ数がわからないことによるもののようです(stat.ccの215~218行目261行目)。 トランジスタ数は、cost.hの53行目~68行目に定義されているものが使われていそうです。 XORゲートは10トランジスタで作れるはず(XORゲート - Wikipedia)ですが、それ以外はあっていそうです。

なお、ソースコードのどこで定義されているかは、GitHub上でcmosと検索することで発見しました。

面積

dffmap libraryとかでググると、以下のページが出てきます。

tom01h.exblog.jp

どうやらOklahoma State Universityが180nm用のスタンダードセルライブラリを公開しているようです。 osu018_stdcells.lib(243KB)をダウンロードして、実験用ディレクトリに入れます。 これを、cmos_cells.libの代わりに使います。

$ ~/yosys/bin/yosys
yosys> read_verilog counter.v
yosys> hierarchy -check
yosys> proc; opt; fsm; opt; memory; opt
yosys> techmap; opt
yosys> dfflibmap -liberty osu018_stdcells.lib
yosys> abc -liberty osu018_stdcells.lib
yosys> stat -liberty osu018_stdcells.lib

出力は以下のようになりました。

12. Printing statistics.

=== counter ===

   Number of wires:                 34
   Number of wire bits:             46
   Number of public wires:           4
   Number of public wire bits:       7
   Number of memories:               0
   Number of memory bits:            0
   Number of processes:              0
   Number of cells:                 17
     AND2X1                          1
     AOI21X1                         3
     DFFPOSX1                        4
     INVX1                           2
     NAND3X1                         1
     NOR2X1                          3
     OAI21X1                         2
     XNOR2X1                         1

   Chip area for module '\counter': 754.000000

スタンダードセルの面積の確認

トランジスタを並べて回路を実現するとき、トランジスタを自由に配置できるとすると検証が大変です。 そこで、よくありそうな回路について検証されたトランジスタ配置を作っておき、それを並べて回路を実現することを考えます。 そのような検証されたトランジスタ配置がスタンダードセルです。 スタンダードセルは並べて使うことが前提なので、回路の外枠が長方形になっていて、かつ高さが固定値です。 高さが固定なので、トランジスタ数の多い回路は幅が広くなります。

osu018のスタンダードセルライブラリは、pMOSトランジスタとnMOSトランジスタのペア当たり、高さ8、幅1、となっていると思われます。

  • NOTゲート:2トランジスタ必要で1列、隣とのスペースに1列、で2列なので面積は16
  • NANDゲートやNORゲート:4トランジスタ必要で2列、隣とのスペースに1列、で3列なので面積は24
  • ANDゲートやORゲート:6トランジスタ必要で3列、隣とのスペースに1列、で4列なので面積は32
  • XORゲート:12トランジスタ必要で6列、隣とのスペースに1列、で7列なので面積は56(やっぱり12トランジスタで作るのが主流なのでしょうか)
  • 半加算器:XORとANDなので18トランジスタ必要で9列、隣とのスペースに1列、で10列なので面積は80
  • 全加算器:負論理で24トランジスタ必要※で12列、反転に4トランジスタ必要で2列、隣とのスペースに1列で15列なので面積は120

※参考文献:https://www.hindawi.com/journals/vlsi/2012/173079/XORゲート - Wikipediaの参考文献として載っていました)

三入力NANDゲートは同様の推論で行くと32になりそうなところ36になっていて謎ですが、他は大体あっていそうです。

次に読む

さきほどのOklahoma State Universityのスタンダードセルライブラリを紹介していたブログの他の記事で、遅延の評価をやっているので、今度はこれをやってみたいと思います。

tom01h.exblog.jp

AVX-512の機能を使ったlogf(x)の実装(その2)

前回(AVX-512の機能を使ったlogf(x)の実装(その1) - よーる)の記事で、AVX-512でベクトル実行することを前提としたlogfの高速実装を作りました。 速度を重視しつつもできる限り精度に気を付けた結果、ほとんどの入力に対して誤差を1.5ULP未満とできました。 ただし、対数関数の定義域に属する2139095039個の単精度浮動小数点数のうち、1502個についてだけは誤差が1.5ULPを超えてしまいました(それでも最大誤差は2.2ULP未満です)。 今回は誤差の原因を詳細に解析し、その誤差の要因を避けることのできる係数を発見することで、最大誤差1.5ULPを保証するlogfの高速実装を作ります。

前回の実装の確認

前回の実装の本体を示します。 補助関数まで示すと冗長になるので省略しました(vgetexpps浮動小数点数を仮想的に*1正規化して指数部を取り出す関数、vgetmantps浮動小数点数を仮想的に正規化して仮数部を取り出す関数、vpermpsは16エントリの表引きを行う関数です)。

float my_logf( float x ) {
    float expo = vgetexpps( x );
    float mant = vgetmantps( x );
    float idxf = mant + 0x1.p+19;
    if( mant >= 0x1.78p+0f ) {
        expo += 1.0f;
        mant *= 0.5f;
    }

    float invs = vpermps( idxf, invs_table );
    float t    = std::fma( mant, invs, -1.0f );
    float logs = vpermps( idxf, logs_table );

    float poly = std::fma( std::fma( std::fma( -0x1.fe9d24p-3f, t, 0x1.557ee2p-2f ), t, -0x1.00000cp-1f ), t, 1.0f );
    float ret  = std::fma( poly, t, std::fma( expo, 0x1.62e430p-1f, logs ) );
    return ret;
}

ようするに、xを仮想的に正規化したときの指数部expo e)とxを仮想的に正規化したときの仮数部から決まる16通りの定数invs \sigma_k = s_k^{-1})を使って、 \log x = e\log2 + \log s_k + \log\frac x{2^es_k} のように計算しているということです。

誤差の要因の確認

x1.0fから十分離れている時

 \left|\log\frac x {2^es}\right| |\log x| より十分小さく、 \log\frac x  {2^es} の計算で生じる誤差が最終結果に与える影響は十分小さいです。 よって素直に計算すれば大きな誤差が発生することはありません(実際、1.5ULPを超える誤差が発生する1502ケースはいずれも  \exp(-2^{-6}) \lt x \lt \exp(2^{-5}) です)。

x1.0fに近いとき

expo0.0fとなります。よって、std::fma( expo, 0x1.62e430p-1f, logs )logsになります。 この時、最終結果の誤差は主に以下の四つの誤差の和からなります。

  1. 最後のstd::fma丸め誤差
  2. logs \log sの差
  3. t丸め誤差
  4. poly \log(1+t)/tの差のt

このうち、一つ目は明らかに避けることができません(0.5ULP)。 二つ目は、前回説明したように、 \sigma_k を適切に選ぶことで十分小さくできます。 三つめは、invs1.0fでない時は普通にやると避けることができません(0.5ULP)。

さて、残る「poly \log(1+t)/t の差のt倍」はどうでしょうか。

poly \log(1+t)/t の差」に最も効いてくるのはpolyの計算の最後の丸め誤差(0.5ULP)です。 この0.5ULPの誤差が厄介で、polyが1より小さければ  2^{-25} なのですが、polyが1より大きいと  2^{-24} になってしまいます。  2^{-25} の誤差のt倍は、t仮数部が2に近いときtに対して0.5ULPであり、 2^{-24} の誤差のt倍は、t仮数部が2に近いときtに対して1.0ULPです。 実際にはこれに加えて計算途中の丸め誤差の影響と近似誤差が追加されます。

 k = 0の時

 t = x - 1 になり、tを求める部分では丸め誤差が発生しません。 polyが1より大きくて丸め誤差が大きくなってしまうのは  x \lt 1 の時で、t仮数部が2に近ければlogf(x)の指数部とtの指数部は同じになります。 その時、最終結果に与える影響は、一つ目の誤差が0.5ULP、二つ目の誤差が0ULP、三つ目の誤差が0ULP、四つ目の誤差が1.0ULP以上、で合わせて1.5ULP以上になります。 すべての最悪ケースを同時に引くことはなさそうですが、それでも1.5ULPを保証するのはかなり厳しそうに見えます。

 k \ne 0の時

logf(x)の指数部に比べてtの指数部が十分小さければ、「x1.0fから十分離れている時」と同様の論理で、素直に計算すれば大きな誤差が発生することがありません。

tの指数部がlogf(x)の指数部と同じになる場合を考えます。 これは例えば、 x = 0.984 付近の時に  \sigma_{15} = 32/31 t = x\sigma_{15} -1 = 1.00748\times2^{-6} \log x = -1.03228\times2^{-6} になるような場合のことです。 すると、最終結果に与える影響は、一つ目の誤差が0.5ULP、二つ目の誤差がほぼ0ULP、三つ目の誤差が0.5ULP、四つ目の誤差が0.5ULP以上、で合わせて1.5ULP以上になります。 tの指数部が大きいということはつまり多項式近似の端の方なので、近似誤差を含めて1.5ULP未満に抑えることは難しそうです。 さらに言えば、 x = 1.0315 付近の時などはもっとひどくて、一つ目の誤差が0.5ULP、二つ目の誤差がほぼ0ULP、三つ目の誤差が0.5ULP、四つ目の誤差が1.0ULP以上、で合わせて2.0ULP以上になって、どうにもなりません(実際、2.0ULPを超える誤差が発生する5ケースは全てこれです)。

このような問題が起こるのは、 \sigma_k が切り替わる境界がlogfの指数部やtの指数部が切り替わる境界とずれているからです。 それをまとめたのが以下の表です。

本来あるべき境界 実際の境界 問題点
 \sigma_{15} \sigma_0 の境界  (1+2^{-6})s_{15}=0.983887  \frac{63}{64}=0.984375  k = 15 の負担が大
 \sigma_0 \sigma_1 の境界  \exp(2^{-5})=1.031743  \frac{33}{32}=1.03125  k = 1 の負担が大

このうち、 s_{15}がかかわるほうの問題は s_{15}を適切にとることで解消できます。

戦略

以下の二つの戦略により、上記問題を解決します。

  • idxfを求める部分で浮動小数点数加算ではなく浮動小数点数積和演算を用いることで、命令数を増やさずに境界をずらす
  • 多項式近似の係数を工夫して、 \log(1+t)/t \gt 1 となる範囲( t \lt 0)では極力近似誤差が0に近くなるようにする

さらに、上記戦略をうまく働かせるため、 k ごとに  t としてあり得る範囲を調整します。 これは  \sigma_k \left(1+\frac{k}{16}\right)^{-1} から大胆にずらすことで実現します。

終結果に与える最大誤差が以下のようになることを目指して最適化します。

誤差要因  k = 0 の場合  k = 15 または  k = 1 の場合
最後のstd::fma丸め誤差 0.5ULP 0.5ULP
logs \log sの差 0ULP ほぼ0ULP
t丸め誤差 0ULP 0.5ULP
poly \log(1+t)/tの差のt 1.0ULP 0.5ULP

idxfを求める部分の変更

 \sigma_0 \sigma_1 の境界が  \exp(2^{-5})=1.031743 より大きくなり、かつ  \sigma_{15} \sigma_0 の境界が  \exp(-2^{-6})=0.984496 以下となれば良いです。

実際は近似多項式の係数と同時に最適化するのですが、結論から言うと以下のように設定すればうまくいきました。

    float idxf = fma( mant, 0x1.fd9c88p-1f, 0x1.p+19f );

このようにすることで、 k = 1 かつ  t \lt 0 の場合で誤差が大きくなる問題を防ぎます。 代わりに  k = 0 の場合に取り扱わなければいけないtの範囲が  -0.016130 \lt t \le 0.03125 から  -0.011012 \lt t \lt 0.036084 になります。 負の側は範囲が狭くなり、正の側は範囲が広くなっています。 負の側は精度の要求が厳しく、正の側は精度の要求が緩いので、それと合致しています。 単に  k = 1 かつ  t \lt 0 の場合の回避だけでない効果をここに求めているようです。

 \sigma_k の選択

以下を満たしていれば意外と自由に取れます。

  •  \sigma_{15}
    • tの上限が  2^{-7} 以下になるようにする
      • 境界変更により  -2^{-6} \lt \log(x) \lt -2^{-7} となる付近を取り扱わないといけないので、  2^{-7} \lt t だとpoly \log(1+t)/t の差のt倍が最終結果に対して0.5ULP以上になってしまう
      • 多項式近似の区間の幅を制約するので、tの上限はなるべく  2^{-7} に近いほうが良い
  •  \sigma_1
    • tの上限が  k = 0 の場合のtの上限以下になるようにする
    • tの下限が  k = 15 の場合のtの下限以下になるようにする

多項式近似の工夫

上記  \sigma_k の選択の下、以下を守れば全体が1.5ULP保証になります。 実際には、ほんの少し守れていなくても、誤差要因全てがそろうことはないので、結果的に1.5ULP保証できる実装が手に入ります。

  • 多項式近似の誤差が  t=-2^{-N}(N=7,8,9,\dots) でほぼ0になるようにする
    • poly \log(1+t)/t の差のt倍」はt仮数部が2に近いときtのULPに対して大きくなるので、 t が二の冪乗の時が重要
    •  k = 0 において、t仮数部が2に近くて t \lt 0ならばlogf(x)のULPはtのULPと同じ
    • poly丸め誤差だけで最終結果に与える影響が1.0ULPあるので、多項式近似誤差が少しでもあると1.0ULPを超えてしまう
  • poly \log(1+t)/t の差のt倍について、
    • 区間の左端  t = L において  2^{-29} 未満になるようにする
      •  k = 0 では使われず、 k = 15 k = 1 で使われる
      • その時  2^{-5} \lt |\log x| \lt 2^{-4}なので、logf(x)に対しての0.5ULPは 2^{-29}
    • 区間の右端  t = R において  2^{-28} 未満になるようにする
      •  k = 0 で使われ、その時  2^{-5} \lt \log x \lt 2^{-4}なので、logf(x)に対しての1.0ULPは 2^{-28}
      •  k = 1 で使われ、その時  2^{-4} \lt \log x \lt 2^{-3}なので、logf(x)に対しての0.5ULPは 2^{-28}
        • ただし、 k = 1 の場合にはtの指数部がlogf(x)の指数部よりも一つ小さくてt丸め誤差が最終結果に与える影響が0.25ULPしかないので実際には問題とならない
    •  t \in [\exp(2^{-N-1})-1, \exp(2^{-N})-1)(N=5,6,7,\dots) において  2^{-N-24} 未満になるようにする
      •  k = 0 において重要で、その区間において  2^{-N-1} \le \log x \lt 2^{-N}なので、logf(x)に対しての1.0ULPは 2^{-N-24}

これを満たす近似多項式は、以下のようになります(限界までは最適化していません)。

    const float C4 = -0x1.fb4ef8p-3f;
    const float C3 =  0x1.556fc4p-2f;
    const float C2 = -0x1.ffffe2p-2f;

    float poly = std::fma( std::fma( std::fma( C4, t, C3 ), t, C2 ), t, 1.0f );

この時、tの値ごとの「t丸め誤差」と「poly \log(1+t)/t の差のt倍」が最終結果に与えるULPでみた影響の最大値は、図1のようになります。

図1: tの値ごとの最終丸め誤差以外の誤差が最終結果に与えるULPでみた影響の最大値のグラフ

なんとか±1.0ULPに収まっています。

最終的な実装のソースコード

// -------------------------------------------------------
//  Copyright 2023 lpha-z
//  https://lpha-z.hatenablog.com/entry/2023/09/03/231500
//  Distributed under the MIT license
//  https://opensource.org/license/mit/
// -------------------------------------------------------

#include <bit>
#include <cmath>
#include <cstdint>

float vgetexpps( float x ) {
    if( std::abs(x) < 0x1.p-126f ) {
        return vgetexpps( x * 0x1.p+126f ) - 126.f;
    }
    return static_cast<float>(std::bit_cast<std::uint32_t>(x) << 1 >> 24) - 127.f;
}

float vgetmantps( float x ) {
    if( std::abs(x) < 0x1.p-126f ) {
        x *= 0x1.p+126f;
    }
    return std::bit_cast<float>(std::bit_cast<std::uint32_t>(x) & 0x00ffffff | std::bit_cast<std::uint32_t>(1.0f));
}

float vpermps( float x, float (&tbl)[16] ) {
    return tbl[std::bit_cast<std::uint32_t>(x) & 0xf];
}

float invs_table[16] = {
0x1.0000000p+0f,
0x1.e286920p-1f,
0x1.c726fe0p-1f,
0x1.af35980p-1f,
0x1.99a95e0p-1f,
0x1.861a9e0p-1f,
0x1.746c640p-1f,
0x1.6435820p-1f,
0x1.5564f40p+0f,
0x1.47a8960p+0f,
0x1.3b1c5e0p+0f,
0x1.2f640a0p+0f,
0x1.24958c0p+0f,
0x1.1a813e0p+0f,
0x1.11180c0p+0f,
0x1.04d9b40p+0f,
};
float logs_table[16] = {
+0x0.000000p+0f,
+0x1.e5b538p-5f,
+0x1.e2118ap-4f,
+0x1.5fb476p-3f,
+0x1.c8b0a8p-3f,
+0x1.166fecp-2f,
+0x1.45eeaap-2f,
+0x1.7383aap-2f,
-0x1.26c4fcp-2f,
-0x1.f96f70p-3f,
-0x1.a97736p-3f,
-0x1.5bd74ap-3f,
-0x1.118fbcp-3f,
-0x1.9387e8p-4f,
-0x1.08c23ep-4f,
-0x1.338588p-6f,
};

float my_logf( float x ) {
    float expo = vgetexpps( x );
    float mant = vgetmantps( x );
    float idxf = fma( mant, 0x1.fd9c88p-1f, 0x1.p+19f );
    if( mant >= 0x1.79c328p+0f ) {
        expo += 1.0f;
        mant *= 0.5f;
    }

    float invs = vpermps( idxf, invs_table );
    float t    = std::fma( mant, invs, -1.0f );
    float logs = vpermps( idxf, logs_table );

    const float C4 = -0x1.fb1370p-3f;
    const float C3 =  0x1.556f14p-2f;
    const float C2 = -0x1.ffffe2p-2f;
    float poly = std::fma( std::fma( std::fma( C4, t, C3 ), t, C2 ), t, 1.0f );
    float ret  = std::fma( poly, t, std::fma( expo, 0x1.62e430p-1f, logs ) );

    return ret;
}

これらの改良により上記実装は先週示した実装と比べて命令数を増やさずに最大誤差を小さくすることができました。 対数関数の定義域に属する2139095039個(0x7f7fffff個)の単精度浮動小数点数について誤差を評価してみると、

となりました。  2 \le k \le 14 に対する  \sigma_k は適当にとったので、tの絶対値が小さくなるよう範囲の中心にとれば平均誤差はもう少し減らせるかもしれません。 ただ、0.5ULPを超える誤差が頻繁に発生する原因はstd::fma( expo, 0x1.62e430p-1f, logs )の丸めなので、係数を工夫しても0.5ULPを超える誤差が発生する頻度を下げることは難しそうです。 std::fma( expo, 0x1.62e430p-1f, std::fma( poly, t, logs ) )の順番で計算すれば0.5ULPを超える誤差が発生する頻度を3.3%まで下げることができますが、レイテンシが一命令分伸びます。

また、最大の誤差は-1.45944ULP/+1.47702ULPでした。 ULPで見た最大誤差を1.5ULPより有意に小さくすることは、命令数を増やさずには不可能だと思います。 アルゴリズムを大幅に見直せば回避できる可能性はありますが、おそらく命令数が増えます。

まとめ

logf(x)関数に含まれる誤差の原因を詳細に調べました。 idxfの計算に積和演算を使うことでテーブルの担当領域をずらし、さらに注意深く近似多項式を設計することにより、最大誤差1.5ULP未満を保証する実装を作ることに成功しました。

おまけ:没になった戦略

 \sigma_1 として下位ビットに0が並ぶような浮動小数点数を選択すれば、tを計算するときの丸め誤差を0にすることができます。 しかし、 \sigma_1 として選べる浮動小数点数の選択肢が少なく、 \log s_1浮動小数点数に丸めるときの誤差を小さくすることができません。  \log s_1 の指数部はtの指数部より一つ大きいので、tを計算するときの丸め誤差0.5ULPよりも最終結果に与える影響を小さくしたければ、 \log s_1丸め誤差は0.25ULP以内に収めなければなりません。 これを満たすことができる \sigma_10x1.efp-1くらいしかなく、理想的な値0x1.e2p-1から離れすぎています。 よって、この方法は採用しませんでした。

*1:通常の浮動小数点数には指数部の制限があり非正規化数がありますが、その制限を無視して取り扱うということを意味します。

AVX-512の機能を使ったlogf(x)の実装(その1)

2023年6月に光成滋生さんがAVX-512特有の機能(通常の浮動小数点数演算以外の命令)を使用したlogf(x)のベクトル実装を公開しました(解説記事→AVX-512によるvpermpsを用いたlog(x)の実装)。 具体的には、指数部を取り出すvgetexpps命令、仮数部を取り出すvgetmantps命令、高速なテーブル参照を実現するvpermps命令の三つを使用しています。

これらの機能は非常に強力であり、正しく組み合わせることでlog関数を大幅に高速化することができます。 しかし、このfmathの実装は、AVX-512特有の機能を使わない部分に改善の余地がありそうです。 いろいろな部分を改善していたら、オリジナルとかなり異なるlogf(x)の実装が得られました。

fmathの実装のC++

fmathの実装はアセンブリコードを出力するPythonコードの形で公開されており、読み解くことが困難です(そもそもビルドする方法がわかりませんでした……)。 とりあえず読み解いた感じでは、以下のようなC++コードとほぼ等価な計算を行っているようです。

#include <bit>
#include <cmath>
#include <cstdint>

float vgetexpps( float x ) {
    if( std::abs(x) < 0x1.p-126f ) {
        return vgetexpps( x * 0x1.p+126f ) - 126.f;
    }
    return static_cast<float>(std::bit_cast<std::uint32_t>(x) << 1 >> 24) - 127.f;
}

float vgetmantps( float x ) {
    if( std::abs(x) < 0x1.p-126f ) {
        x *= 0x1.p+126f;
    }
    return std::bit_cast<float>(std::bit_cast<std::uint32_t>(x) & 0x00ffffff | std::bit_cast<std::uint32_t>(1.0f));
}

float vpermps( float x, float (&tbl)[16] ) {
    return tbl[std::bit_cast<std::uint32_t>(x) & 0xf];
}

float vpsrad( float x, std::size_t shamt ) {
    return std::bit_cast<float>(std::bit_cast<std::int32_t>(x) >> shamt);
}

// 以下の部分は、MITSUNARI Shigeoさんがmodified new BSD Licenseで公開
//(https://github.com/herumi/fmath/tree/master)しているfmathライブラリ
// に含まれるfmath_logf_avx512関数の翻案です(不正確かもしれません)。

float tbl1[16] = {
std::bit_cast<float>(0x3f783e10), // 32/33
std::bit_cast<float>(0x3f6a0ea1), // 32/35
std::bit_cast<float>(0x3f5d67c9), // 32/37
std::bit_cast<float>(0x3f520d21), // 32/39
std::bit_cast<float>(0x3f47ce0c), // 32/41
std::bit_cast<float>(0x3f3e82fa), // 32/43
std::bit_cast<float>(0x3f360b61), // 32/45
std::bit_cast<float>(0x3f2e4c41), // 32/47
std::bit_cast<float>(0x3f272f05), // 32/49
std::bit_cast<float>(0x3f20a0a1), // 32/51
std::bit_cast<float>(0x3f1a90e8), // 32/53
std::bit_cast<float>(0x3f14f209), // 32/55
std::bit_cast<float>(0x3f0fb824), // 32/57
std::bit_cast<float>(0x3f0ad8f3), // 32/59
std::bit_cast<float>(0x3f064b8a), // 32/61
std::bit_cast<float>(0x3f020821), // 32/63
};
float tbl2[16] = {
std::bit_cast<float>(0xbcfc14c8), // log(32/33)ではなくlog(tbl1[0])
std::bit_cast<float>(0xbdb78694), // 以下同様
std::bit_cast<float>(0xbe14aa96),
std::bit_cast<float>(0xbe4a92d4),
std::bit_cast<float>(0xbe7dc8c6),
std::bit_cast<float>(0xbe974716),
std::bit_cast<float>(0xbeae8ded),
std::bit_cast<float>(0xbec4d19d),
std::bit_cast<float>(0xbeda27bd),
std::bit_cast<float>(0xbeeea34f),
std::bit_cast<float>(0xbf012a95),
std::bit_cast<float>(0xbf0aa61f),
std::bit_cast<float>(0xbf13caf0),
std::bit_cast<float>(0xbf1c9f07),
std::bit_cast<float>(0xbf2527c4),
std::bit_cast<float>(0xbf2d6a01),
};
float LOG2 = std::bit_cast<float>(0x3f317218);

float fmath_logf( float x ) {
    float expo = vgetexpps( x );
    float mant = vgetmantps( x );
    float idxf = vpsrad( mant, 23 - 4 );

    float b    = vpermps( idxf, tbl1 );
    float c    = std::fma( mant, b, -1.0f );
    float logb = vpermps( idxf, tbl2 );

    float z    = std::fma( expo, LOG2, -logb );

    float xm1  = x - 1.0f;
    float abs  = std::bit_cast<float>(std::bit_cast<std::uint32_t>(xm1) & 0x7fffffff);
    if( abs < 0.02f ) {
        c = xm1;
        z = 0.0f;
    }

    float ret  = std::fma( std::fma( std::fma( std::fma( -0.25008487f, c, 0.3333955701f ), c, -0.49999999f ), c, 1.0f ), c, z );
    return ret;
}

改良

事前計算点の改良

この実装では -\log\left(1+\frac{k+0.5}{16}\right) (0\le k\lt16)を事前計算したテーブルを用意しています。 0.5ずれた点の値を持っているのはidxfを切り捨てで作っているのに合わせているからです。 しかしその代償としてx1.0f付近(logf(x)0.0f付近になって精度の要求が厳しくなる所)の時に精度が足りなくなっていて、それを解決するために分岐が発生しています。 幸い、アルゴリズムを大きく変える必要があるわけではないので、マスク付き演算で実現できますが、追加コストとして五命令かかっています。

この問題は、素直に \log\left(1+\frac{k}{16}\right) (0\le k\lt16)を事前計算したテーブルを持つことで解決可能です。 そのためにはidexfを最近接丸めで作る必要がありますが、これは「すごく大きな数を足すと仮数部が丸められる」というテクニック(これも光成滋生さんの解説で知りました→高速な倍精度指数関数expの実装)を使えばよいです。 具体的には、以下のようにします。

    float mant = vgetmantps( x ); // 1.0f <= mant < 2.0f で出てくる
    float idxf = mant + 0x1.p+19; // 最近接の1/16の倍数に丸める
    if( mant >= 0x1.f8p+0f ) { // 0x1.f8p+0f <= mant < 2.0f の時
            mant *= 0.5f; // mant が b = 1.0 に近くなるように mant をずらす
            expo += 1.0f; // expo もずらす
    }

二命令減らし、精度のための追加コストを三命令にすることができました。 ただし、このif文を外すと完全におかしな値を計算することになるので、「精度を犠牲にしても高速化したいときはこの範囲を消してください」というコードにはなっていません。

テーブル値の精度を上げる

上で行ったexpomantのずらしは、もっと積極的にやってよいです。 境界を0x1.f8p+0fではなく0x1.78p+0fとして、 8\le k \lt 15用の値を \log\left(1+\frac{k}{16}\right)ではなく \log\left(\frac12+\frac{k}{32}\right)にします。 これによりlogbの絶対値が小さくなって精度が上がるので、全体の精度が向上します。

近似多項式の改良

使われている係数が謎です。 見た感じでは |c| \lt 0.019に対する最良近似多項式っぽいですが、範囲をそのようにする理由がわかりません(範囲は |c| \lt \frac1{33} = 0.030303のはずです)。 何か間違いがあるように思えます。

上の最適化を行った場合のcの範囲である -\frac{0.5}{17} \lt c \lt \frac{0.5}{16}を前提にExcelを使って最良近似多項式を求めたところ、四次の係数は-0x1.fe9d24p-3f、三次の係数は0x1.557ee2p-2f、二次の係数は-0x1.00000cp-1f、とすれば良さそうだということがわかりました。 当該範囲での近似誤差は、絶対誤差 6.8\times10^{-10} = 1.46\times2^{-31}未満です。

テーブル値の精度が高くなるように事前計算点を選ぶ

よく考えると、事前計算に用いる点 \sigma_k \left(1+\frac{k}{16}\right)^{-1}浮動小数点数に丸めたものでなくてもよいはずです。 あまりにも離れた値を使うと |c|が大きくなって多項式近似誤差が大きくなってしまいますが、多少のずれは許容されます。 そこで、 \sigma_kとして \left(1+\frac{k}{16}\right)^{-1}にそれなりに近い浮動小数点数のうち \log \sigma_k浮動小数点数に丸めたときに誤差が十分小さくなるものを採用しましょう。

精度的に重要なのは、 \sigma_{15}の修正のようです。 ここで、 \sigma_{15}として適当な値は \frac{32}{31}付近には全くありません。 これは \left.\frac{d}{dx}\log(x)\right|_{x=\frac{32}{31}}=\frac{31}{32}でありこの値が二進浮動小数点数で正確に表せる数であることによるもので、この付近では隣の浮動小数点数におけるlogの返り値も下位ビット(=丸め誤差)がほぼ同じになります。 そのため、 \sigma_{15}0x1.084210p+0から0x1.084550p+0へとかなり大きくずらす必要がありました。

なお、この考え方は元のソースコードの精度を上げるif文がないバージョン(0.5ずれた点でテーブルを作るので命令数が少ない代わりに精度が低い方法)にも適用可能です。 tbl1[0] = 0x1.f081dcp-1ftbl2[0] = -0x1.f76c56p-6f(真の値は-0x1.f76c55ffa3p-6付近なので丸め誤差は0.00071ULP)を使うことで |\log(x)| \lt 2^{-6}での最大絶対誤差を 3.4\times10^{-9} = 1.79 \times2^{-29}未満にすることができます。

改良版のソースコード

#include <bit>
#include <cmath>
#include <cstdint>

float vgetexpps( float x ) {
    if( std::abs(x) < 0x1.p-126f ) {
        return vgetexpps( x * 0x1.p+126f ) - 126.f;
    }
    return static_cast<float>(std::bit_cast<std::uint32_t>(x) << 1 >> 24) - 127.f;
}

float vgetmantps( float x ) {
    if( std::abs(x) < 0x1.p-126f ) {
        x *= 0x1.p+126f;
    }
    return std::bit_cast<float>(std::bit_cast<std::uint32_t>(x) & 0x00ffffff | std::bit_cast<std::uint32_t>(1.0f));
}

float vpermps( float x, float (&tbl)[16] ) {
    return tbl[std::bit_cast<std::uint32_t>(x) & 0xf];
}

float invs_table[16] = {
0x1.000000p+0f, // 16/16
0x1.e1e1e2p-1f, // 16/17
0x1.c71c72p-1f, // 16/18
0x1.af286cp-1f, // 16/19
0x1.99999ap-1f, // 16/20
0x1.861862p-1f, // 16/21
0x1.745d18p-1f, // 16/22
0x1.642c86p-1f, // 16/23
0x1.555556p+0f, // 32/24
0x1.47ae14p+0f, // 32/25
0x1.3b13b2p+0f, // 32/26
0x1.2f684cp+0f, // 32/27
0x1.24924ap+0f, // 32/28
0x1.1a7b96p+0f, // 32/29
0x1.111112p+0f, // 32/30
0x1.084550p+0f, // ~32/31
};
float logs_table[16] = {
+0x0.000000p+0f,
+0x1.f0a30ap-5f,
+0x1.e27074p-4f,
+0x1.5ff306p-3f,
+0x1.c8ff7ap-3f,
+0x1.1675cap-2f,
+0x1.4618bap-2f,
+0x1.739d7ep-2f,
-0x1.269624p-2f,
-0x1.f991c4p-3f,
-0x1.a93ed8p-3f,
-0x1.5bf408p-3f,
-0x1.1178eep-3f,
-0x1.9335e4p-4f,
-0x1.08599ap-4f,
-0x1.047a88p-5f,
};

float my_logf( float x ) {
    float expo = vgetexpps( x );
    float mant = vgetmantps( x );
    float idxf = mant + 0x1.p+19;
    if( mant >= 0x1.78p+0f ) {
        expo += 1.0f;
        mant *= 0.5f;
    }

    float invs = vpermps( idxf, invs_table );
    float t    = std::fma( mant, invs, -1.0f );
    float logs = vpermps( idxf, logs_table );

    float poly = std::fma( std::fma( std::fma( -0x1.fe9d24p-3f, t, 0x1.557ee2p-2f ), t, -0x1.00000cp-1f ), t, 1.0f );
    float ret  = std::fma( poly, t, std::fma( expo, 0x1.62e430p-1f, logs ) );
    return ret;
}

これらの改良により上記実装はfmathの実装よりも二命令減らしつつ最大誤差も小さくすることができました。 対数関数の定義域に属する2139095039個(0x7f7fffff個)の単精度浮動小数点数について誤差を評価してみると、

であり、最大の誤差は-1.925ULP/+2.122ULPでした。

この結果を見ると、ほとんどの場合は真値に最近接の浮動小数点数かその前後の浮動小数点数以外を返しており、それ以外の浮動小数点数を返す(≒1.5ULPを超える誤差が発生する)ことはほとんどないことがわかります。 誤差の原因を詳細に調べて各所の係数を工夫したところ、最大誤差1.5ULP未満を保証する実装を作ることに成功しました。 ただ、その説明は長くなるので、今回の記事はここまでにします。

おまけ:速度重視版のソースコードの改良版

 |\log(x)| \lt 2^{-6}での絶対誤差は 3.4\times10^{-9} = 1.79 \times2^{-29}未満となります。 その外側では-2.289ULP/+1.157ULPです。 まともに最適化していないので、誤差はもう少し小さくできそうです。

#include <bit>
#include <cmath>
#include <cstdint>

float vgetexpps( float x ) {
    if( std::abs(x) < 0x1.p-126f ) {
        return vgetexpps( x * 0x1.p+126f ) - 126.f;
    }
    return static_cast<float>(std::bit_cast<std::uint32_t>(x) << 1 >> 24) - 127.f;
}

float vgetmantps( float x ) {
    if( std::abs(x) < 0x1.p-126f ) {
        x *= 0x1.p+126f;
    }
    return std::bit_cast<float>(std::bit_cast<std::uint32_t>(x) & 0x00ffffff | std::bit_cast<std::uint32_t>(1.0f));
}

float vpermps( float x, float (&tbl)[16] ) {
    return tbl[std::bit_cast<std::uint32_t>(x) & 0xf];
}

float vpsrad( float x, std::size_t shamt ) {
    return std::bit_cast<float>(std::bit_cast<std::int32_t>(x) >> shamt);
}

// 以下の部分は、MITSUNARI Shigeoさんがmodified new BSD Licenseで公開
//(https://github.com/herumi/fmath/tree/master)しているfmathライブラリ
// に含まれるfmath_logf_avx512関数をC++に翻案したものをベースとし、
// テーブル値と多項式の係数を変更することで精度を改善したものです。

float tbl1[16] = {
0x1.f081dcp-1f,
0x1.d42b36p-1f,
0x1.badefep-1f,
0x1.a43a42p-1f,
0x1.8fa128p-1f,
0x1.7c91dcp-1f,
0x1.6bb3b6p-1f,
0x1.5c7e0ap-1f,
0x1.4e9076p-1f,
0x1.40c2dep-1f,
0x1.34bd30p-1f,
0x1.29e50ap-1f,
0x1.1f3f34p-1f,
0x1.16067ap-1f,
0x1.0c996cp-1f,
0x1.048122p-1f,
};
float tbl2[16] = {
-0x1.f76c56p-6f,
-0x1.6e9312p-4f,
-0x1.290ddap-3f,
-0x1.9489aep-3f,
-0x1.fb779ap-3f,
-0x1.2fc65cp-2f,
-0x1.5e3292p-2f,
-0x1.89f0fep-2f,
-0x1.b3b51ap-2f,
-0x1.ded9ccp-2f,
-0x1.02fbeep-1f,
-0x1.154a94p-1f,
-0x1.27ed54p-1f,
-0x1.38a234p-1f,
-0x1.4a4b10p-1f,
-0x1.59f5fap-1f,
};
float LOG2 = std::bit_cast<float>(0x3f317218);

float fast_logf( float x ) {
    float expo = vgetexpps( x );
    float mant = vgetmantps( x );
    float idxf = vpsrad( mant, 23 - 4 );

    float b    = vpermps( idxf, tbl1 );
    float c    = std::fma( mant, b, -1.0f );
    float logb = vpermps( idxf, tbl2 );

    float z    = std::fma( expo, LOG2, -logb );

    float ret  = std::fma( std::fma( std::fma( std::fma( -0x1.0049acp-2f, c, 0x1.557f3ep-2f ), c, -0x1.fffff8p-2f ), c, 1.0f ), c, z );
    return ret;
}

中点の正しい計算方法

二つの浮動小数点数 a, bについて、中点の数学的な値は \frac{a+b}2です。 以下の記事に触発されたので、この中点を浮動小数点数に正しく丸めた値を得る方法について考えます。

中点はどうやって計算するべきか - kashiの日記

当該記事によれば、この問題は意外と難しく、分岐なしに浮動小数点数演算だけで正しく求める方法はまだ知られていないようです。

何が難しいのか

足したらオーバーフローする

まず、もっとも単純に思いつくのは(a+b)*0.5と計算する方法でしょう。 しかしこれは、a+bの時点でオーバーフローする可能性があるので、中点を正しく求めることができません。

引いてもオーバーフローする

高精度に計算する方法として、a + (b-a)*0.5という方法が紹介されることもあります。 しかし、今度はb-aがオーバーフローする可能性があります。 それは、baの符号が異なり、それぞれ絶対値がすごく大きい場合です。 まぁそんな二数の中点を正確に求めたいということはなさそうですが、正しいアルゴリズムではないという点が気になります。

アンダーフローした際の丸め誤差

では、a*0.5 + b*0.5a + (b*0.5 - a*0.5)とするのはどうでしょうか。 これらはオーバーフローの可能性はありませんが、今度は逆にアンダーフローの可能性が出てきます。 通常、二進浮動小数点数の0.5倍は丸め誤差なしに正確に表せますが、結果がアンダーフローする場合は丸め誤差が生じる可能性があります。 最終結果に与える影響が同じ方向の丸め誤差が二つ以上の演算で発生した場合、正しくない答えが求まることがあります。

正しく求める方法

以下では、実装を検証しやすいようにfloatを使います(double用のmidpoint関数は入力のパターンが 2^{128}通りあって検証が困難です)。 floatの特徴はほぼ使っていないので、簡単にdoubleに置き換えることができます。

作戦

std::fma( a, 0.5f, b*0.5f )がほとんどの場合に正しい値を返すことに注目します。 これが正しくない丸めを行う必要条件は、b*0.5fの計算で丸め誤差が発生することです。 この丸め誤差に由来する真の中点とstd::fma( a, 0.5f, b*0.5f )の差の二倍hを算出することを考えます。

使用テクニック

double-doubleのように、実数を二つの浮動小数点数 x, yの組で表現することを考えます。 しかし、double-doubleのように x + yを表現するのではなく、 x + 0.5yを表現することを考えます。

この表現を使うと、任意の浮動小数点数 aについて、 0.5aが表現できます。  x=0, y=aとすればよいのでこれ自体は自明です。 重要なのは、以下のTwoProduct類似の手順で正規化することができることです。

std::pair<float, float> Half( float a ) {
    float hi = a * 0.5f;
    float lo = std::fma( hi, -2.0f, a );
    return { hi, lo };
}

この時、lo0.0fFLT_TRUE_MIN-FLT_TRUE_MINになります(丸めモードが何であってもそうなります)。

このように0.5倍を無誤差で扱う方法を導入します。

概略

まず、先のHalf関数を用いて、次の e, fを求めます。

auto [e, f] = Half(b)

つづいて、仮の中点値 m'を求めます。

float m_ = std::fma(a, 0.5f, e)

 m'と真の中点の差の二倍 hを計算します。

float h = f ? std::fma( m_, -2.0f, a ) + b : 0.0f;

最後に、 m'にそれを足し合わせ、最終結果を求めます。

float ret = std::fma( h, 0.5f, m_ );

もちろん、これでは分岐が発生しているのでダメです。 しかし、f0.0fFLT_TRUE_MIN-FLT_TRUE_MINしかありえないことを思い出します。 これを利用すると、以下のようにして分岐を消すことができます。

std::fesetround(FE_TOWARDZERO);
float half_or_zero = std::fma( f, f, -FLT_MIN/FLT_TRUE_MIN ) + FLT_MIN/FLT_TRUE_MIN;
std::fesetround(FE_TONEAREST);

float m_ = std::fma( a, 0.5f, e );
float h = std::fma( m_, -2.0f, a ) + b;

float ret = std::fma( h, half_or_zero, m_ );

これでも微妙にダメな点がいくつかあります。 一つ目は、b±0x1.fffffep-126fの時にうまくいかないことです。 この問題は、eを求めるときに零への丸めを使えば回避できます。 二つ目は、hがオーバーフローして無限大になると0.0fをかけてもNaNになってしまうことです。 この問題も、hを求めるときに零への丸めを使えば回避できます。

以上でおそらくすべての問題が解決されます。

具体的なソースコード

float TrueMidpoint( float a, float b ) {
    int rm = std::fegetround();

    // Half
    std::fesetround(FE_TOWARDZERO);
    float e = b * 0.5f;
    float f = std::fma( e, -2.0f, b );

    float half_or_zero = std::fma( f, f, -0x1.p+23f ) + 0x1.p+23f;

    // FastTwoSum(?)
    std::fesetround(rm);
    float m_ = std::fma( a, 0.5f, e );
    std::fesetround(FE_TOWARDZERO);
    float h = std::fma( m_, -2.0f, a ) + b;

    std::fesetround(rm);
    float ret = std::fma( h, half_or_zero, m_ );

    return ret;
}

8命令で実現することができました。

FastTwoSumをさかさまに使うと0.0fが返ってくることを利用してfを見ずにできないかといろいろ試しましたが、うまくいく方法を見つけることができませんでした。

検証

次の分岐ありコード(中点はどうやって計算するべきか - kashiの日記で紹介されている方法)と結果が一致するかを確認します。

float TrueMidpointWithIf( float a, float b ) {
    if( std::abs(a) > 1.0f && std::abs(b) > 1.0f ) {
        float ha = a * 0.5f; // a*0.5fは丸め誤差なし
        float hb = b * 0.5f; // b*0.5fは丸め誤差なし
        float ret = ha + hb; // std::fma(a, 0.5f, b*0.5f)で良さそうだけれど、文献通りにしておく
        return ret; // よって、誤差の生じる丸めは一回だけで、所望の方向に丸められる
    } else {
        float sum = a + b; // 非正規化数になるときは丸め誤差なし
        float ret = sum * 0.5f; // sumが正規化数ならば丸め誤差なし
        return ret; // よって、誤差の生じる丸めは一回だけで、所望の方向に丸められる
    }
}

実験に使ったソースコードは以下の通りです。

#include <cstdint>
#include <cstring>
#include <iostream>
#include <random>

template<class To, class From>
To bit_cast( From from ) {
    To to;
    std::memcpy( &to, &from, sizeof to );
    return to;
}

float getRandom( std::mt19937_64& mt ) {
        std::uint64_t s = mt();
        switch( s % 4 ) {
        case 0: return bit_cast<float, std::uint32_t>( ( s >> 32 & 0x80ffffff ) ); // tiny float
        case 1: return bit_cast<float, std::uint32_t>( ( s >> 32 & 0x80ffffff ) + 0x7e800000 ); // large float
        case 2: return bit_cast<float, std::uint32_t>( s >> 32 | 0x007ffff0 ); // large mantissa float
        case 3: return bit_cast<float, std::uint32_t>( s >> 32 ); // random float
        default: return 0.0f;
        }
}

int main( int argc, char* argv[] ) {
        static std::mt19937_64 mt(std::atoi(argv[1]));
        for( std::uint64_t i = 0; i < 100'000'000; ++i ) {
                float a = getRandom(mt);
                float b = getRandom(mt);
                if( a != a || b != b ) { continue; }
                if( bit_cast<std::uint32_t>(TrueMidpoint( a, b )) != bit_cast<std::uint32_t>(TrueMidpointWithIf( a, b )) ) {
                    std::cout << std::hexfloat << a << " " << b << " ";
                    std::cout << TrueMidpoint( a, b ) << " ";
                    std::cout << TrueMidpointWithIf( a, b ) << std::endl;
                }
        }
}

乱数のシード値としては1から16を使いました。 いずれの場合も出力がなかったため、TrueMidpointTrueMidpointWithIfと同じ動作をするようです。

ただし、このテストで確かめられていないコーナーケース入力について、以下の問題があることがわかっています。

  • TrueMidpoint( -0.0f, -0.0f )-0.0fを返してほしいが、0.0fを返してしまう
  • TrueMidpointの引数に無限大が含まれると、NaNが返ってきてしまう

これらは、two-component演算を使っている限り避けようがない問題です。

改良1

上で使った分岐を消す方法を応用して、TrueMidpointWithIf*1に含まれる分岐を消すことを考えます。

float TrueMidpoint( float a, float b ) {
    int rm = std::fegetround();
    float f = 0x1.p-149f / std::fma( b, 4.0f, 0x1.p-148f );
    std::fesetround(FE_TOWARDZERO);
    float one_or_half = std::fma( f, f, -0x1.fffffep+22f ) + 0x1.p+23f;
    std::fesetround(rm);
    float half_or_one = 1.5f - one_or_half;
    float sum_or_mid = std::fma( a, one_or_half, b * one_or_half );
    float ret = sum_or_mid * half_or_one;
    return ret;
}

このようにすると、除算という重たい演算が含まれてしまいますが、入力として無限大が来る場合や、入力の双方が-0.0fである場合も正しく取り扱うことができます。 無限大を普通の数に戻す方法は除算しかないので、この除算は不可避です。

改良2

入力として無限大が来ないならば、除算を取り除くことができます。

float TrueMidpoint( float a, float b ) {
    int rm = std::fegetround();
    float f = b * 0x1.p-149;
    std::fesetround(FE_TOWARDZERO);
    float half_or_one = std::fma( f, f, -0x1.fffffep+22f ) + 0x1.p+23f;
    std::fesetround(rm);
    float one_or_half = 1.5f - half_or_one;
    float sum_or_mid = std::fma( a, one_or_half, b * one_or_half );
    float ret = sum_or_mid * half_or_one;
    return ret;
}

7命令まで減りました。

まとめ

正しく中点を求める、分岐のないアルゴリズムを紹介しました。 命令数が多いので、実用的ではありません。

*1:正確にはfmaを使うように変更してbの絶対値の大小のみで分岐するようにしたバージョン

constexpr関数の評価速度

最近のC++では、constexprを使って手軽にコンパイル時計算ができます。 C++20ではstd::vectorstd::stringなどの動的なデータ構造もconstexpr指定されたため、さらに利便性が高まりました。 コンパイル時に多少の探索を行って最適な値を発見して埋め込む、といった利用方法も現実的となりました。

一方で、constexpr関数の評価は実質的にインタプリタ実行であり、速度が気になります。 今回はそれを調査してみました。

測定環境

  • Intel Core i9 12900K (Alder Lake-S)
    • 5.1 GHzくらいで動いている
  • Ubuntu 20.04.3 LTS on WSL2 (Windows 11)
  • g++ (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0(aptでインストールしたもの)
  • clang++ version 12.0.1(公式によるx86_64-unknown-linux-gnu向けプレビルドバイナリ)
  • 時間の測定は何回か実行した平均

実験に使ったソースコード

#include <iostream>

X unsigned long long f( unsigned long long n ) {
        unsigned long long x = 1234567890987654321;
        for( ; n--; ) {
                x ^= x << 7;
                x ^= x >> 9;
        }
        return x;
}

int main() {
        Y unsigned long long x = f( N );
        std::cout << x << std::endl;
}

X, Y, Nはコンパイルオプションで指定します。

g++

g++は、以下の戦略でコンパイル時計算を行うようです。

  • 関数がconstexpr指定されていなくて、変数がconstexpr指定されていなければ、コンパイル時計算を行わない。
  • 関数がconstexpr指定されていなくて、変数がconstexpr指定されていれば、コンパイルエラー(C++の仕様)。
  • 関数がconstexpr指定されていて、変数がconstexpr指定されていなければ、
    • とりあえずコンパイル時計算してみる。
    • 評価回数上限(-fconstexpr-loop-limit-fconstexpr-ops-limitで指定できる)に達したら、コンパイル時計算を打ち切る。
  • 関数がconstexpr指定されていて、変数がconstexpr指定されていなければ、
    • とりあえずコンパイル時計算してみる。
    • 評価回数上限(-fconstexpr-loop-limit-fconstexpr-ops-limitで指定できる)に達したら、コンパイルエラーにする。

実行時間は以下のようになりました。

X Y Opt limit N=1000000 N=10000000
なし なし -O0 default 0.16 s 0.15 s
なし なし -O2 default 0.17 s 0.16 s
constexpr なし -O0 default 0.49 s 0.48 s
constexpr なし -O2 default 1.08 s 1.15 s
constexpr なし -O0 999999999 1.89 s 23.3 s
constexpr なし -O2 999999999 1.97 s 23.6 s
constexpr constexpr -O0 999999999 1.94 s 22.5 s
constexpr constexpr -O2 999999999 1.92 s 23.3 s

普通の命令セットだと一周当たり4cycleくらいで実行可能な命令列にコンパイルされるはずなので、5GHzのCPUで動かすと1秒間に109周以上できるはずです。 107周に23秒かかっているので、コンパイルされたコードの2500倍程度遅いということがわかりました。

clang++

clang++は、変数がconstexpr指定されている時にのみコンパイル時計算を行うようです。

実行時間は以下のようになりました。

X Y Opt limit N=1000000 N=10000000 N=30000000 N=100000000
なし なし -O0 999999999 0.16 s 0.17 s
なし なし -O2 999999999 0.18 s 0.20 s
constexpr なし -O0 999999999 0.20 s 0.18 s
constexpr なし -O2 999999999 0.17 s 0.18 s
constexpr constexpr -O0 999999999 0.83 s 6.8 s
constexpr constexpr -O2 999999999 0.85 s 6.8 s 20 s 67 s

108周に67秒かかっているので、コンパイルされたコードの700倍程度遅いということがわかりました。

まとめ

コンパイル時計算(constexpr関数の評価)は機械語コードの実行と比べて三桁(オーダーで1000倍)程度遅いことがわかりました。

符号なし乗算器の作り方の勉強と11ビット乗算器の設計

 Nビット符号なし乗算器は、 Nビットの符号なし整数二つを受け取り、 2Nビットの符号なし整数を出力する回路です。

 Nビット乗算器は、部分積を作る N^2個のANDゲートと N(N-2)個の全加算器(full adder, FA)を組み合わせて作ることができます。 この構成で全加算器が N(N-2)個必要なのは、以下のような説明ができます。 全加算器は3入力2出力なので信号線を一本減らす効果があります。 部分積は全部で N^2個あり、これを 2N本に減らす必要があります。 よって、 N^2 - 2N = N(N-2)本減らすためには、 N(N-2)個の全加算器が必要です。

以下では、さらに効率的に乗算器を作るいくつかのテクニックを書き、その後11ビット乗算器を設計してみます。

テクニック

Wallace Tree

Wallace Tree(ワラス木、またはより原音に近くウォレス木)は、 N(N-2)個の全加算器をどのように配置するかの方法の一つです。 Wallace Treeでは、全加算器をバランスした木状に配置します。 順次加算していく方式ではレイテンシが \Theta(N)段となってしまいますが、バランスした木とすることでレイテンシを \Theta(\log N)段に抑えます。

実際には、本当に全加算器だけで作ってしまうと桁上げ伝搬加算器(ripple-carry adder)になってしまって、 \Omega(N)段のレイテンシが不可避です。 そこで、各位の信号線の本数が2本になるまではWallace Treeで減らしていって、その先は桁上げ先読み加算器などの高速加算器を使う、といった方法が使われます。 Wallace Treeで各位の信号線が2本になるまで減らすレイテンシが \Theta(\log N)段、高速加算器のレイテンシが \Theta(\log N)段、ということで全体として \Theta(\log N)段のレイテンシが達成できます。

Wallace Treeの模式図として、図1のようなものが紹介されることがあります。 しかし、これは「わかっている」技術者向けに書かれた、信号線の減り具合を示すための模式図です。 桁上げは次の位に送らなければいけませんから、実際の回路は図2のようになります*1

図1: 19入力2出力のWallace Tree(六段)の模式図

図2: 六段のWallace Treeの実際。橙色で示した線は下位からの桁上げ信号

Booth encoding

Booth encoding(ブース符号化)は、部分積を減らす方法の一つです。 Radix-2、Radix-4、Radix-8、……などがあります。

Radix-2 Booth encoding

Radix-2 Booth encodingは、乗数の二進表現に1が連続した部分がある場合、それが(シフトと)足し算一回と引き算一回で実現できるという事実を利用します。 普通の乗算器で31倍を実現するには16倍、8倍、4倍、2倍、1倍、の5つの数を足しこむ必要があります。 しかし、被乗数の32倍を足し、被乗数の1倍を引く、とすれば実現できます。 Radix-2 Booth encodingは、このような最適化を行うことです。 なぜRadix-2 Booth「encoding」というかというと、この手順は乗数を二進表現(各位が0か1)から冗長二進表現(各位が-1か0か1)に変換していることに相当するからです。

31倍のような特殊な数だけではなく、乗数が0b0011111011110のような1が連続している部分がありさえすれば、この手法を適用することで加算の回数を減らすことができます。 残念ながら、Radix-2 Booth encodingを用いても加算の回数を減らせない乗数がある(例えば0b010101010101とか)ので、ハードウェア実装には向きません。 ハードウェア実装では、最悪に備えて回路を用意しないといけないからです。 手回し式計算器を使ったり、乗算器のないCPUで計算したりする際などには役立つようです。 あるいは、可変レイテンシの直列乗算器を作るためにも使えそうです。

Radix-4 Booth encoding

Radix-4 Booth encodingはこの仕組みを昇華させたものであり、非常に効率的であるためハードウェア乗算器の実装でよく用いられるようです。 この方法では、乗数を二桁ごとに区切って、四進表現(各位が0か1か2か3)から冗長四進表現(各位が-2か-1か0か1か2)に変換します。 ここで、各位としてあり得る数に3がなくなっているのが重要です。 3倍は加算器を使わないと表現できませんが、-2倍、-1倍、0倍、1倍、2倍は加算器なしにシフトだけで簡単に作ることができます(※本当は符号反転に加算器が必要ですが解決する方法があります。後述します)。 これにより、部分積が乗数の二桁ごとに一つとなるため、加算すべき部分積を半分に減らすことができます。 代わりに-2倍・-1倍・0倍・1倍・2倍から選択する回路が必要になりますが、部分積の数が半分になることはWallace Tree約二段分(二段で9本を4本にすることができる)に相当するので、メリットが上回ります。

エンコーディングするためには、a[3:1], a[5:3], a[7:5], a[9:7], ...のように、二桁ずつずれていく重なりのある三桁を取り出す必要があります。 具体的には、a×bを計算するとき、

  • a[2k+1:2k-1]が000なら、bの0倍を出力する
  • a[2k+1:2k-1]が001か010なら、bの 1\times2^{2k}倍を出力する
  • a[2k+1:2k-1]が011なら、bの 2\times2^{2k}倍を出力する
  • a[2k+1:2k-1]が100なら、bの -2\times2^{2k}倍を出力する
  • a[2k+1:2k-1]が101か110なら、bの -1\times2^{2k}倍を出力する
  • a[2k+1:2k-1]が111なら、bの0倍を出力する

とします。

これに加えてa[0]由来の部分積を作る必要があります。 このためには、a[-1]=0だと思ってa[1:(-1)]について同様のエンコーディングを行えばよいです。 この項は-2倍か-1倍か0倍か1倍が出てきます。

また、最上位付近をどこまで足さないといけないかも注意が必要です。 と言ってもほぼ明らかで、最上位ビットa[N-1]の影響があるところまで足せばよいです。

 Nが奇数であれば、a[N]=0だと思ってa[N, N-2]の分まで足せばよいです。 この項は、0倍か1倍か2倍が出てきます。 つまり、負の数は出てきません。

 Nが偶数であれば、a[N+1]=a[N]=0だと思ってa[N+1, N-1]の分まで足せばよいです。 この項は、0倍か1倍が出てきます。 つまり普通のANDゲートで作られる部分積です。

Radix-8 Booth encoding

Radix-8 Booth encodingはさらにこの仕組みを発展させたもので、乗数を三桁ごとに区切って八進表現(各位が0, 1, 2, 3, 4, 5, 6, 7のいずれか)から冗長八進表現(各位が-4, -3, -2, -1, 0, 1, 2, 3, 4のいずれか)に変換します。 冗長八進表現だとシフトだけでは作れない3倍が必要ですが、これは一回作ってしまえば使いまわせるので、十分大きなビット幅の乗算器であれば加算すべき部分積を三分の一に減らせる効果が上回ります。 部分積を三分の一に減らせる効果がRadix-4で部分積を二分の一に減らしたのよりも上回る境界は、3倍を作るコストを考慮に入れても一応 N=6程度なのでほとんどの乗算器に適用可能に見えます。 しかし、3倍を計算するレイテンシは隠蔽できず、またRadix-4とRadix-8の差分は部分積の数が三分の二になること、つまりWallace Tree一段分の効果しかないのでレイテンシを短くする効果はありません。 極力トランジスタ数を減らしたいということではない限り、Radix-4で十分に思えます。

符号反転の方法

Booth encodingを行うと、符号なし乗算器を作っているのに負の数が出てきて面倒ですが、うまく取り扱う方法が知られています。 以下Radix-4を前提とした説明を行いますが、負の数を取り扱うアイデア自体はRadixによらず適用可能です。

さて、「-2倍、-1倍、0倍、1倍、2倍は加算器なしにシフトだけで簡単に作ることができる」と書きましたが、実際にはこれは正しくありません。 符号を反転させるためには、一の補数を取った後に1を足さないといけなくて、ここに加算器が必要だからです。 この問題は、最後に足す1を一種の部分積だとみなすことにより解決できます。

また、このままだと符号拡張のためのビットが部分積として残ってしまい、部分積が四分の三に減るにとどまります。 ここで、符号拡張で生じる部分積は、すべて0(例えば、0b000000)かすべて1(例えば、0b111111)である点に注目します。 あらかじめ0b111111という定数を足す回路を用意しておけば、すべて1の時はそのままでよく、すべて0の時はそれに1を足せばよいです( \overline{S}を足せばよいです)。 また、足す定数は N/2個(合計 N^2/4ビットくらい)出現しますが、事前にその和を計算しておけば、部分積として入力しなければいけない定数は一つ( Nビットくらい)にできます。 この定数は、Radix-4の時0b010101...0101011になります。 なお、この符号拡張をキャンセルする手法は、符号付き乗算器を設計するときにも役立ちます。

さて、この定数の最後は0b011となっていますが、この部分は符号ビットの残り( \overline{S_0})と足し合わせると0b \overline{S_0}S_0S_0になります。 このようにすることで、部分積が多い位ができてしまうことを回避できます。

これを説明したのが図3です。 最後のRadix-4 Booth乗算器は、図解としてよく見るものと一致しています。

図3: Booth encodingで現れる負の数への対処方法。一つ目:-bを~b+1と表現し、足すべき1も部分積とみなす。二つ目:符号拡張で出現するビットを定数+符号ビットの反転( \overline{S})のように表現する。三つ目:  \overline{S_0}と定数の加算を実行することで、部分積の数は変わらないが位あたりの最大部分積数を削減できる。

その他の高速乗算アルゴリズムの適用

相当大きな乗算器でない限り、Karatsuba法やToom–Cook乗算、フーリエ変換を使った乗算などは役に立ちません。 計算量がナイーブ法を下回るビット幅が最も小さいKaratsuba法を例にとって、なぜ役に立たないのかを説明します。

Karatsuba法は、乗算回数を減らす代わりに加減算の回数が増える方法です。 ここで、ハードウェア乗算器においては、1ビット×1ビットの乗算はANDゲート1つでできる一方、これで生じた信号線を1本減らすには5~6ゲートからなる全加算器が必要です。 つまり、ANDゲートより5倍以上大きな全加算器の数が重要であり、その観点からすると乗算と加算はほぼ同じコストです。 これがKaratsuba法がビット数が相当大きくない限り役に立たない直観的な理由です。 部分積が減って全加算器が減る量と前処理・後処理で全加算器が増えてしまう量の大小により、Karatsuba法を使用すべきかが決まります。

必要な全加算器の数を概算してみます。 Radix-4 Booth乗算器では、全加算器が \frac1 2 N^2個程度必要です。 ここに(再帰的にではなく一段階だけ*2)Karatsuba法を適用すると、前処理に 2N個程度、後処理に 6N個程度の全加算器が必要で、代わりに乗算部分の全加算器を \frac3 8 N^2個程度に減らせます。 これが釣り合うのは N=64の時です。 したがって、Radix-4 Booth encodingを使用していれば、64ビット乗算器まででKaratsuba法を使うメリットは全くありません。 同様の計算により、128ビット乗算器であってもRadix-16 Booth encodingを採用すれば十分で、Karatsuba法を採用するメリットはなさそうだということがわかります。 それ以上、例えば256ビット乗算器を作る際には、Karatsuba法が役立つかもしれません。

Toom-2.5はさらに加算が多いので、相当大きなビット数でないとメリットがありません。 Radix-16 Booth encodingを前提にすると、1024ビット乗算器あたりでKaratsuba法に追いつくようです。 Toom-3以降は係数に \frac1 6などが出てきてもう一度乗算(循環小数定数での乗算なので実質的には \Theta(\log N)回の Nビット加算)が必要なので、実用するのはほぼ不可能でしょう。 フーリエ変換などなおさらです。

Wallace Treeの実際の設計

基本

Wallace Treeを設計するには、図4のように各位に何本の信号線があるかを書いたExcelシートを使うと便利です。 この時、各位にある信号線を3(全加算器の入力数)で割って切り捨てた数だけ、その位に全加算器を配置します。 これを何段か繰り返して、すべての位で信号線の本数が2以下になったとき、Wallace Treeが終了します(図5)。

図4: Excelを使ったWallace Treeの設計(その1)

図5: Excelシートを使ったWallace Treeの設計(その2)

アンチパターン

ただし、これだけだと各位の信号線の本数が上位ビットから「2, 2, 2, ..., 2, 2, 3」などとなった時、桁上げ伝搬加算器ができてしまいます。 この時にのみ、半加算器(half adder, HA)を使います。 半加算器は、2入力2出力で信号線の本数が減らないので極力使いたくないですが、アンチパターンが発生した場合だけは使うことでWallace Treeの段数を大幅に削減することができます。 その設計を行ったのが図6です。

図6: 半加算器を使ったWallace Treeのアンチパターンの回避

図7のようにもっと前の段から半加算器を使うことで段数を削減できることがあるようですが(Dadda乗算器と関係している?)、よくわかっていません。

図7: 半加算器を技巧的に配置することで段数を減らした例

11ビット符号なし乗算器の設計

なぜ11ビットを選んだかというと、半精度浮動小数点数(IEEE754のbinary16)の仮数部が(暗黙の1を含めて)11ビットだからです。 基本的には、Wallace TreeとRadix-4 Booth encodingを組み合わせて設計します。 定数との加算を実行し、0b \overline{S_0}S_0S_0とするタイプで設計します。

すると、図8のようになります。

図8: 11ビット乗算器の設計

0b \overline{S_0}S_0S_0を使うタイプで設計したため、一つの位あたりの最大の信号線の本数は六本となり、うまく三段に収めることができました。

Radix-4 Booth乗算器を素直に作ると、各位に所属する部分積の個数の分布がきれいな山形にならずギザギザした感じになるので、半加算器を多く使う必要があるようです。 この問題を解決するため、Booth encoderを工夫し、先に一桁だけ桁上げを行ってしまうことを考えます。 すると、図9のようになります。

図9: 少しだけ工夫した11ビット乗算器の設計

各位に所属する部分積の個数の分布がきれいな山形になったためか、半加算器の数を7から4に減らすことができました。

さて、「先に一桁だけ桁上げを行う」という部分で回路が増えていたら半加算器の数が減っても意味がありません。 そこで、その部分の回路を実際に確認してみます。 先に一桁だけ桁上げを行うRadix-4 Booth encoderの-2倍/-1倍/0倍/1倍/2倍、を選択する回路は以下を出力すればよいです。

a[2:0] 出力   ... out[3] out[2] out[1]その2 out[1]その1 out[0]
000 0b ... 0 0 0 0 0
001 1b ... b[3] b[2] b[1] 0 b[0]
010 1b ... b[3] b[2] b[1] 0 b[0]
011 2b ... b[2] b[1] b[0] 0 0
100 -2b ... ~b[2] ~b[1] ~b[0] 1 0
101 -1b ... ~b[3] ~b[2] ~b[1] ~b[0] b[0]
110 -1b ... ~b[3] ~b[2] ~b[1] ~b[0] b[0]
111 0b ... 1 1 1 1 0

この回路は、図10のように実現することができます。

図10: 先に一桁だけ桁上げを行うRadix-4 Booth encoder

これに対して通常のRadix-4 Booth encoderは図11のような回路となります。

図11: 通常のRadix-4 Booth encoder

ほぼ回路は変わっていないため、先に一桁だけ桁上げを行うことで半加算器を減らすメリットを受けられることがわかりました。

Verilog HDL実装

ここまでの実装をVerilog HDLで記述しました。 この実装は、あり得る222通りの入力すべてに対して正しい答えを返すことをVerilatorにより確認しています。

module multiplier11(lhs, rhs, result);
    input  wire[10:0] lhs;
    input  wire[10:0] rhs;
    output wire[21:0] result;

    function[11:0] Radix_4_Booth_select_012;
        input [1:0] a;
        input[10:0] b;

        case( a )
        2'b000: Radix_4_Booth_select_012 = 12'b0;
        2'b001: Radix_4_Booth_select_012 = { 1'b0, b };
        2'b010: Radix_4_Booth_select_012 = { 1'b0, b };
        2'b011: Radix_4_Booth_select_012 = { b, 1'b0 };
        endcase
    endfunction

    function[14:0] Radix_4_Booth;
        input [2:0] a;
        input[10:0] b;
        Radix_4_Booth[14] =  1'b1;
        Radix_4_Booth[13] = ~a[2];
        Radix_4_Booth[12:1] = Radix_4_Booth_select_012( a[1:0] ^ {2{a[2]}}, b );
        if(a[2])
            Radix_4_Booth[12:2] = ~Radix_4_Booth[12:2];
        Radix_4_Booth[0] = a[2] & ~Radix_4_Booth[1];
    endfunction

    function[15:0] Radix_4_Booth_lsb;
        input [1:0] a;
        input[10:0] b;
        Radix_4_Booth_lsb[15] = ~a[1];
        Radix_4_Booth_lsb[14] =  a[1];
        Radix_4_Booth_lsb[13] =  a[1];
        Radix_4_Booth_lsb[12:1] = Radix_4_Booth_select_012( a[1:0] ^ { 1'b0, a[1] }, b );
        if(a[1])
            Radix_4_Booth_lsb[12:2] = ~Radix_4_Booth_lsb[12:2];
        Radix_4_Booth_lsb[0] = a[1] & ~Radix_4_Booth_lsb[1];
    endfunction

    function[11:0] Radix_4_Booth_odd_msb;
        input [1:0] a;
        input[10:0] b;
        Radix_4_Booth_odd_msb = Radix_4_Booth_select_012( a, b );
    endfunction

    wire[15:0] a0b  = Radix_4_Booth_lsb    (lhs[ 1:0], rhs);
    wire[14:0] a2b  = Radix_4_Booth        (lhs[ 3:1], rhs);
    wire[14:0] a4b  = Radix_4_Booth        (lhs[ 5:3], rhs);
    wire[14:0] a6b  = Radix_4_Booth        (lhs[ 7:5], rhs);
    wire[14:0] a8b  = Radix_4_Booth        (lhs[ 9:7], rhs);
    wire[11:0] a10b = Radix_4_Booth_odd_msb(lhs[10:9], rhs);

/*
    assign result =   { 7'b0, a0b[15:1]        } + { 20'b0, a0b[0], 1'b0 }
                    + { 6'b0, a2b[14:1],  2'b0 } + { 18'b0, a2b[0], 3'b0 }
                    + { 4'b0, a4b[14:1],  4'b0 } + { 16'b0, a4b[0], 5'b0 }
                    + { 2'b0, a6b[14:1],  6'b0 } + { 14'b0, a6b[0], 7'b0 }
                    + {       a8b[14:1],  8'b0 } + { 12'b0, a8b[0], 9'b0 }
                    + {      a10b      , 10'b0 } ;
*/

    function[1:0] full_adder;
        input a;
        input b;
        input c;
        full_adder = { a&b|b&c|c&a, a^b^c };
    endfunction

    wire[1:0] x3  = full_adder(a0b[ 4], a2b[ 2], a2b[ 0]);
    wire[1:0] x4  = full_adder(a0b[ 5], a2b[ 3], a4b[ 1]);
    wire[1:0] x5  = full_adder(a0b[ 6], a2b[ 4], a4b[ 2]);
    wire[1:0] x6  = full_adder(a0b[ 7], a2b[ 5], a4b[ 3]);
    wire[1:0] x7  = full_adder(a0b[ 8], a2b[ 6], a4b[ 4]);
    wire[1:0] x8  = full_adder(a0b[ 9], a2b[ 7], a4b[ 5]);
    wire[1:0] x9  = full_adder(a0b[10], a2b[ 8], a4b[ 6]);
    wire[1:0] x10 = full_adder(a0b[11], a2b[ 9], a4b[ 7]);
    wire[1:0] x11 = full_adder(a0b[12], a2b[10], a4b[ 8]);
    wire[1:0] x12 = full_adder(a0b[13], a2b[11], a4b[ 9]);
    wire[1:0] x13 = full_adder(a0b[14], a2b[12], a4b[10]);
    wire[1:0] x14 = full_adder(a0b[15], a2b[13], a4b[11]);
    wire[1:0] y9  = full_adder(a6b[ 4], a8b[ 2], a8b[ 0]);
    wire[1:0] y10 = full_adder(a6b[ 5], a8b[ 3], a10b[0]);
    wire[1:0] y11 = full_adder(a6b[ 6], a8b[ 4], a10b[1]);
    wire[1:0] y12 = full_adder(a6b[ 7], a8b[ 5], a10b[2]);
    wire[1:0] y13 = full_adder(a6b[ 8], a8b[ 6], a10b[3]);
    wire[1:0] y14 = full_adder(a6b[ 9], a8b[ 7], a10b[4]);
    wire[1:0] y15 = full_adder(a6b[10], a8b[ 8], a10b[5]);
    wire[1:0] y16 = full_adder(a6b[11], a8b[ 9], a10b[6]);
    wire[1:0] y17 = full_adder(a6b[12], a8b[10], a10b[7]);
    wire[1:0] y18 = full_adder(a6b[13], a8b[11], a10b[8]);
    wire[1:0] y19 = full_adder(a6b[14], a8b[12], a10b[9]);

    wire[1:0] z5  = full_adder( x5 [0], x4 [1], a4b[0]);
    wire[1:0] z6  = full_adder( x6 [0], x5 [1], a6b[1]);
    wire[1:0] z7  = full_adder( x7 [0], x6 [1], a6b[2]);
    wire[1:0] z8  = full_adder( x8 [0], x7 [1], a6b[3]);
    wire[1:0] z9  = full_adder( x9 [0], x8 [1], y9 [0]);
    wire[1:0] z10 = full_adder( x10[0], x9 [1], y10[0]);
    wire[1:0] z11 = full_adder( x11[0], x10[1], y11[0]);
    wire[1:0] z12 = full_adder( x12[0], x11[1], y12[0]);
    wire[1:0] z13 = full_adder( x13[0], x12[1], y13[0]);
    wire[1:0] z14 = full_adder( x14[0], x13[1], y14[0]);
    wire[1:0] h15 = full_adder(a2b[14], x14[1],  1'b0 ); // Half Adder
    wire[1:0] z15 = full_adder(a4b[12], y14[1], y15[0]);
    wire[1:0] z16 = full_adder(a4b[13], y15[1], y16[0]);
    wire[1:0] z17 = full_adder(a4b[14], y16[1], y17[0]);
    wire[1:0] z20 = full_adder(a8b[13], a10b[10], y19[1]);

    wire[1:0] w7  = full_adder(z7 [0], z6 [1], a6b[0]);
    wire[1:0] w8  = full_adder(z8 [0], z7 [1], a8b[1]);
    wire[1:0] w9  = full_adder(z9 [0], z8 [1],  1'b0 ); // Half Adder
    wire[1:0] w10 = full_adder(z10[0], z9 [1], y9 [1]);
    wire[1:0] w11 = full_adder(z11[0], z10[1], y10[1]);
    wire[1:0] w12 = full_adder(z12[0], z11[1], y11[1]);
    wire[1:0] w13 = full_adder(z13[0], z12[1], y12[1]);
    wire[1:0] w14 = full_adder(z14[0], z13[1], y13[1]);
    wire[1:0] w15 = full_adder(z15[0], z14[1], h15[0]);
    wire[1:0] w16 = full_adder(z16[0], z15[1], h15[1]);
    wire[1:0] w17 = full_adder(z17[0], z16[1],  1'b0 ); // Half Adder
    wire[1:0] w18 = full_adder(y18[0], z17[1], y17[1]);
    wire[1:0] w19 = full_adder(y19[0],  1'b0 , y18[1]); // Half Adder
    wire[1:0] w21 = full_adder(a8b[14], a10b[11], z20[1]);

    assign result =  { w21[0], w19, w17, w15, w13, w11, w9, w7, z5, x3, a0b[3:1] }
                   + { 1'b0, z20[0], w18, w16, w14, w12, w10, w8, 1'b0, z6[0], 1'b0, x4[0], 1'b0, a2b[1], a0b[0], 1'b0 };
endmodule

full_adderの入力としてどれを選ぶかは改善の余地があります。 今回は信号の名前がなるべくレギュラーになるように配置しましたが、実際は桁上げ信号の配線が短くなるように選ぶとよいでしょう。 また、複合ゲートが使える場合はa&b|b&c|c&aより~(a&b|b&c|c&a)の方が少ないトランジスタ数で作れるので、一部の全加算器は負論理としたほうが良いかもしれません。

符号ビットキャンセル定数伝搬最適化

符号ビットキャンセル定数となんらかの部分積を入力として持つ半加算器は定数伝搬によりほぼノーコスト(NOTゲート一つ)となります。 これを三か所に適用することで、三つの半加算器をなくせるようです。

図12: 符号ビットキャンセル定数を伝搬して半加算器を減らした例

小さな高速加算器を使う

単純に考えると最終段の高速加算器の幅は22ビットとなりますが、下位をWallace Tree実行中に桁上げ伝搬加算器で求めれば高速加算器の幅を縮めることができます。 下位のどこまでを確定させられるかを考えてみると、

  • 20の位は部分積が1つなので既に確定している
  • 21の位はWallace Treeの一段目で確定できる
  • 22の位はWallace Treeの二段目で確定できる
  • 23の位はWallace Treeの三段目で確定できる

となります。 また、最上位ビットは高速加算器のキャリー出力(これは高速加算器の結果が出る一段前に出力される)とのXORで作れるので、ここも高速加算器を通す必要がありません。

ここまでの結果から、高速加算器は24の位から220の位までを取り扱えばよいことがわかりました。 高速加算器の幅が17ビットとなり、微妙に二冪を超えているのでいまいちです。 試行錯誤してみたところ、1bit+2bit加算器と半加算器を適切な位置に配置すれば、24の位もWallace Treeの三段目で確定できることがわかりました。 1bit+2bit加算器は、図13に示すような4ゲートからなる3入力3出力の回路*3で、全加算器よりも小さいのでこれを使うことは問題ないはずです。 out[1]出力の遅延が全加算器一段分よりやや長い気もしますが、ここはクリティカルパスではないので問題ありません。

図13: 1bit+2bit加算器

これにより高速加算器の幅を16ビットとした設計が図14です。

図14: 最後の高速加算器を16ビットとしてみた例

まとめ

  • 乗算器を作るために有用な二つのテクニック
    • Wallace Tree
    • (Radix-4) Booth encoding
  • 実際に11ビット符号なし乗算器を設計した
    • 先に一桁だけ桁上げするRadix-4 Booth encoderを使うことで半加算器を減らせる
    • 符号ビットキャンセル定数を伝搬させることで半加算器を減らせる
    • 1bit+2bit加算器を導入することで、高速加算器の幅を16ビットにすることができた

*1:さらに全加算器の入力位置を取り換えて信号線が短くなる工夫をしてもよいでしょう。

*2:再帰の末端付近ではオーダーが悪いものの定数倍の小さな実装に切り替える」ということをやりたくて、その閾値を求めたいので再帰的に適用しないのが正しいです。

*3:というか半加算器二つを直列につないだものそのものです。

式テンプレートを利用してFastTwoSumを自動生成する

倍精度浮動小数点数doubleは53 bitしか精度がありません。 これを超える精度で計算したいけれど多倍長演算は遅いので避けたい、というときに役立つのがdouble-double(疑似四倍精度)です。 double-doubleはその名の通り、doubleを二つ組み合わせることで高い精度を実現します。 double-douleでは数を二つのdoubleの和として表現します。 これにより、double-dobuleは最低*1でも107 bitの精度を実現します*2

double-double演算ライブラリ

基本

double-doubleの演算は単純ではない(例えば足し算をするとき、ベクトルのように要素ごとに加算すればいいというわけではない)ので、通常はライブラリが用いられます。 C++を利用しているのであれば、kvというヘッダオンリー精度保証付き演算ライブラリに入っているkv::ddを利用するのが最も簡単でしょう。

kv/kv at master · mskashi/kv · GitHub

以下、double-doubleの型をddと書きます。

さて、double-doubleのライブラリを作るときに実装しなければいけない関数は、たくさんあります。 もちろん、dd fma( dd, dd, dd )だけ作ってしまえば、効率はともかく計算することはできます。 それだと無意味な演算が頻発して遅すぎるので、命令数の少ない実装を提供したいということです。

通常の意味で実装が必須となるのは、add関数とmul関数です。 これに加えてfmaも存在すればうれしいですが、kvやcrlibmのライブラリには含まれていませんでした。 十分な精度を持つように実装するのが難しいのでしょうか。

提供されるべきadd関数とmul関数は以下の通りです。

  • dd add( double, double )……TwoSumと呼ばれるdd演算を実装するときの基本パーツなので提供しないことは実質あり得ない
  • dd mul( double, double )……TwoProdと呼ばれるdd演算を実装するときの基本パーツなので提供しないことは実質あり得ない
  • dd add( double, dd )およびdd add( dd, double )
  • dd mul( double, dd )およびdd mul( dd, double )
  • dd add( dd, dd )
  • dd mul( dd, dd )

TwoSumとFastTwoSum

TwoSumは、二つの倍精度浮動小数点数の和をdouble-doubleで求めるアルゴリズムであり、次の実装が知られています。

dd TwoSum( double a, double b ) {
    double x = a + b;
    double tmp = x - a;
    return { x, a - (x-tmp) + (b-tmp) };
}

しかしこれには6演算も必要でコストが高いです。 実は、a >= bであれば、以下の実装(FastTwoSumとかQuickTwoSumとか呼ばれる方式)で問題ありません。

dd FastTwoSum( double a, double b ) {
    double x = a + b;
    return { x, a - x + b };
}

TwoProd

TwoProdは二つの倍精度浮動小数点数の積をdouble-doubleで求めるアルゴリズムです。 それを浮動小数点数演算だけで実現できることは、Dekkerが1971年に示しました。 しかしそのアルゴリズムは非常に技巧的なもので、17 FLOPもかかる高コストなものとなっています。 実は、融合積和演算が使えるのであれば、以下の非常に単純なコードで求めることができます。

dd TwoProd( double a, double b ) {
    double x = a * b;
    return { x, fma( a, b, -x ) };
}

TwoFMAとFastTwoFMA

TwoFMAは倍精度浮動小数点数の融合積和演算の結果をdouble-doubleで求めるアルゴリズムです。 以下のようにできるそうです(参考文献[1])。

dd TwoFMA( double a, double b, double c ) {
    double r   = fma( a, b, c );
    auto[u,du] = TwoProd( a, b );
    auto[x,dx] = TwoSum( c, du );
    auto[y,dy] = TwoSum( u, x );
    double z   = y - r + dy;
    return { r, z + dx };
}

倍精度浮動小数点数の融合積和演算の結果はdouble-doubleで表せるとは限りませんが、triple-double(double三つの和)でなら表せます。 その場合は、最後のz + dxFastTwoSum(z, dx)に置き換えれば目的を達成できます。

fma( a, b, c )をdouble-double精度で求めるのに18もの浮動小数点数演算が必要ですが、a*b-0.5*ccの間の数となる場合は以下の方法が使えます(たぶん)。

dd FastTwoFMA( double a, double b, double c ) {
    double x = fma( a, b, c );
    return { x, fma( a, b, c - x ) };
}

FastTwoFMAからTwoProdとFastTwoSumが導出できる

FastTwoFMAをながめてみると、以下のことがわかります(正確には零の符号とかが怪しいですが、ここでは気にしないでください)。

  • FastTwoSum( a, b )FastTwoFMA( b, 1.0, a )と本質的に同じ計算を行っている
  • TwoProd( a, b )FastTwoFMA( a, b, 0.0 )と本質的に同じ計算を行っている

したがって、原理的にはFastTwoFMAを一度実装してしまえば、適宜1.00.0を代入することでFastTwoSumTwoProdは自動的に導出できそうです。

※TwoFMAは内部でTwoSumとTwoProdを使っていますから、TwoFMAからTwoSumとTwoProdを導くことでTwoSumとTwoProdの実装をサボるというのは無理です。

コンパイラの最適化では自動導出できない

残念ながら、FastTwoFMA( a, b, 0.0 )と書いても、TwoProd( a, b )のように二演算にはならず、余計なごみが残ってしまいます。 これは零にも符号が存在し、(-0.0) + (+0.0)(-0.0)ではなく(+0.0)になるため、単に+ (+0.0)を無視する、というわけにはいかないためです。 この場合に限ってはFastTwoFMA( a, b, -0.0 )と書くことでclangなら見抜いて二演算にしてくれますが、一般には-0.0を利用することは無理です(a * ±0.0みたいな項も無視してもらいたいですが、この項の符号はaが決まらないとわかりません)。

また、FastTwoFMA( a, 1.0, 0.0 ){ a, a-a }を返します。 a - aは基本的には恒等的に0.0ですが、aNaNである場合はそうでないため、それを気にしているようです。

このように、浮動小数点数演算の奇妙な性質が最適化を妨げてしまいます。 コンパイラの最適化に頼るだけでは、FastTwoFMAを実装するだけで効率的なFastTwoSumTwoProdを手に入れることはできないようです。

コンパイルオプションで対処できる

今回の目的においてコンパイラの最適化を妨げているのは、NaNの存在と零の符号の存在です。 これらはそれぞれ、-fno-honor-nansオプションと-fno-signed-zerosオプションを設定することで問題を解決できます。 -fno-honor-nansだけでは、a - a0.0になって-0.0ではないので無駄な命令が残ります。

ただし、ライブラリとして提供するからには、利用者にこのような非標準的なコンパイルオプションを強要するのは避けたいところです。

式テンプレートを利用した自動導出

このような問題が発生するのは、具体的な浮動小数点数である1.00.0を代入してコンパイラの最適化に任せているからです。 そんなことをしなくても、記号的に演算を変形していけばよいはずです。

C++上で式を記号的に取り扱うためには、式自体をオブジェクトとしてとりあつかう式テンプレートというテクニックを使えばよいです。 以前は中間変数へのコピーをなるべく減らす(大仰な)手段の実装に出現することで知られていましたが、ムーブセマンティクスが登場して以降はそのような目的での使用は下火になった感があります*3

以下は、式テンプレートを用いてZeroOneを記号的に取り扱い、それによってdd FastFMA( double, double, double )dd FastFMA2222( dd, dd, dd )だけからいろいろな関数を導出してみた例です。 本質的にはC++03程度の機能で実装できると思われますが、C++17で追加されたif constexprを使うと記述が非常に読みやすくなります(SFINAEを使ったオーバーロードは優先順位がつけにくいため多くの定義を書く必要があります)。

なお、dd FastFMA2222( dd, dd, dd )の実装は、参考文献[2] 715行目 MulAdd22を参考に、問題なさそうなところ*4fmaを使って書き換えたものとしています。

#include <cmath>
#include <utility>

// ==== Basics ====

static constexpr struct Zero {} zero;
static constexpr struct One  {} one ;

template<size_t ID>
struct Var{ double v; double evaluate() const { return v; } };

template<class T, class U>
constexpr bool is_same_impl( T, U ) { return false; }
template<class T>
constexpr bool is_same_impl( T, T ) { return true; }

#define is_same(a,b) is_same_impl( decltype(a){}, decltype(b){} )

// ==== Unary Negate ====

template<class T>
struct Neg { T v; double evaluate() const { return -v.evaluate(); } };

template<class A>
auto operator-( A a ) { return Neg<A>{ a }; };
template<class T>
auto operator-( Neg<T> a ) { return a.v; } // double negation
auto operator-( Zero ) { return zero; } // Zero is the additive identity

// ==== Add ====

template<class A, class B>
struct Add { A a; B b; double evaluate() const { return a.evaluate() + b.evaluate(); } };

template<class A, class B>
auto operator+( A a, B b ) {
    if constexpr( is_same( a, -b ) ) {
        return zero;
    }
    else if constexpr( is_same( a, zero ) ) {
        return b;
    }
    else if constexpr( is_same( b, zero ) ) {
        return a;
    }
    else return Add<A, B>{ a, b };
}

template<class A, class B>
constexpr bool is_same_impl( Add<A, B>, Add<B, A> ) { return true; }

// ==== Sub ====

template<class A, class B>
auto operator-( A a, B b ) {
    if constexpr( is_same( a, zero ) ) {
        return Neg<B>{ b };
    }
    else return a + -b;
}

// ==== FMA ====

template<class A, class B, class C>
struct FMA { A a; B b; C c; double evaluate() const {
    if constexpr( is_same( c, zero ) ) {
        return a.evaluate() * b.evaluate();
    }
    else return std::fma( a.evaluate(), b.evaluate(), c.evaluate() );
} };

template<class A, class B, class C>
auto fma( A a, B b, C c ) {
    if constexpr( is_same( a, zero ) || is_same( b, zero ) ) {
        return c;
    }
    else if constexpr( is_same( a, one ) ) {
        return b + c;
    }
    else if constexpr( is_same( b, one ) ) {
        return a + c;
    }
    else return FMA<A, B, C>{ a, b, c };
}

template<class A, class B, class C>
constexpr bool is_same_impl( FMA<A, B, C>, FMA<B, A, C> ) { return true; }

// ==== Mul ====

template<class A, class B>
auto operator*( A a, B b ) { return fma( a, b, zero ); }

#undef is_same

// ==== Utility ====

using dd = std::pair<double, double>;
namespace std {
    template<> struct tuple_size<Zero> : public integral_constant<std::size_t, 2> {};
    template<> struct tuple_size<One>  : public integral_constant<std::size_t, 2> {};
    template<> struct tuple_element<0, Zero> { using type = Zero; };
    template<> struct tuple_element<1, Zero> { using type = Zero; };
    template<> struct tuple_element<0, One > { using type = One ; };
    template<> struct tuple_element<1, One > { using type = Zero; };
} /* namespace std */
template<std::size_t Index> std::tuple_element_t<Index, Zero> get( Zero ) { return {}; }
template<std::size_t Index> std::tuple_element_t<Index, One > get( One  ) { return {}; }

template<class T, class U>
dd to_dd( std::pair<T, U> expr ) {
    auto [hi, lo] = expr;
    return { hi.evaluate(), lo.evaluate() };
}

template<size_t ID1, size_t ID2>
auto Wrap( dd v ) { return std::pair { Var<ID1>{ get<0>( v ) }, Var<ID2>{ get<1>( v ) } }; }

// ==== Algorithm ====

template<class A, class B, class C>
auto FastTwoFMA( A a, B b, C c ) {
    const auto hi = fma( a, b, c );
    const auto lo = fma( a, b, c - hi );
    return std::pair { hi, lo };
}

template<class A, class B>
auto FastTwoSum ( A a, B b ) { return FastTwoFMA( a, one, b ); }
template<class A, class B>
auto TwoProd( A a, B b ) { return FastTwoFMA( a, b, zero ); }

dd AutoFastTwoFMA( double a, double b, double c ) {
    return to_dd( FastTwoFMA( Var<0>{ a }, Var<1>{ b }, Var<2>{ c } ) );
}
dd AutoFastTwoSum( double a, double b ) {
    return to_dd( FastTwoSum( Var<0>{ a }, Var<1>{ b } ) );
}
dd AutoTwoProd( double a, double b ) {
    return to_dd( TwoProd( Var<0>{ a }, Var<1>{ b } ) );
}

template<class A, class B, class C>
auto FastFMA2222( A a, B b, C c ) {
    auto [ah, al] = a;
    auto [bh, bl] = b;
    auto [ch, cl] = c;
    auto [t1, t2] = TwoProd( ah, bh );
    auto [t3, t4] = FastTwoSum( ch, t1 );
    auto t5 = fma( ah, bl, fma( al, bh, cl + t2 + t4 ) );
    return FastTwoSum( t3, t5 );
}

dd ddAutoMul( dd a, dd b ) {
    return to_dd( FastFMA2222( Wrap<0, 1>( a ), Wrap<2, 3>( b ), zero ) );
}
dd ddAutoFastAdd( dd a, dd b ) {
    return to_dd( FastFMA2222( Wrap<2, 3>( b ), one, Wrap<0, 1>( a ) ) );
}

clang++-12.0.1 -std=c++17 -O2 -mfma -cコンパイルすると、以下の機械語コードが得られます。

0000000000000010 <AutoFastTwoFMA(double, double, double)>:
  10:   c5 f9 28 d9             vmovapd xmm3,xmm1
  14:   c4 e2 f9 a9 da          vfmadd213sd xmm3,xmm0,xmm2
  19:   c5 eb 5c d3             vsubsd xmm2,xmm2,xmm3
  1d:   c4 e2 f9 a9 ca          vfmadd213sd xmm1,xmm0,xmm2
  22:   c5 f9 28 c3             vmovapd xmm0,xmm3
  26:   c3                      ret

0000000000000030 <AutoFastTwoSum(double, double)>:
  30:   c5 fb 58 d1             vaddsd xmm2,xmm0,xmm1
  34:   c5 f3 5c ca             vsubsd xmm1,xmm1,xmm2
  38:   c5 f3 58 c8             vaddsd xmm1,xmm1,xmm0
  3c:   c5 f9 28 c2             vmovapd xmm0,xmm2
  40:   c3                      ret

0000000000000050 <AutoTwoProd(double, double)>:
  50:   c5 fb 59 d1             vmulsd xmm2,xmm0,xmm1
  54:   c4 e2 f9 ab ca          vfmsub213sd xmm1,xmm0,xmm2
  59:   c5 f9 28 c2             vmovapd xmm0,xmm2
  5d:   c3                      ret

0000000000000060 <ddAutoMul(std::pair<double, double>, std::pair<double, double>)>:
  60:   c5 fb 59 e2             vmulsd xmm4,xmm0,xmm2
  64:   c5 f9 28 ea             vmovapd xmm5,xmm2
  68:   c4 e2 f9 ad ec          vfnmadd213sd xmm5,xmm0,xmm4
  6d:   c4 e2 e9 bd e9          vfnmadd231sd xmm5,xmm2,xmm1
  72:   c4 e2 f9 bb eb          vfmsub231sd xmm5,xmm0,xmm3
  77:   c5 db 58 c5             vaddsd xmm0,xmm4,xmm5
  7b:   c5 d3 5c c8             vsubsd xmm1,xmm5,xmm0
  7f:   c5 db 58 c9             vaddsd xmm1,xmm4,xmm1
  83:   c3                      ret

0000000000000090 <ddAutoFastAdd(std::pair<double, double>, std::pair<double, double>)>:
  90:   c5 fb 58 e2             vaddsd xmm4,xmm0,xmm2
  94:   c5 eb 5c d4             vsubsd xmm2,xmm2,xmm4
  98:   c5 eb 58 c0             vaddsd xmm0,xmm2,xmm0
  9c:   c5 fb 58 c1             vaddsd xmm0,xmm0,xmm1
  a0:   c5 fb 58 cb             vaddsd xmm1,xmm0,xmm3
  a4:   c5 db 58 c1             vaddsd xmm0,xmm4,xmm1
  a8:   c5 f3 5c c8             vsubsd xmm1,xmm1,xmm0
  ac:   c5 db 58 c9             vaddsd xmm1,xmm4,xmm1
  b0:   c3                      ret

まず、AutoFastTwoFMAAutoFastTwoSumは三演算、AutoTwoProdは二演算になっているので、この部分は大丈夫そうです。

次に、ddAutoMulですが、非常に小さい値であるa.lo * b.loを無視するタイプのdd乗算の実装が得られています*5。 参考文献[2] 555行目 Mul22と同一なので、これでよさそうです。

最後に、ddAutoFastAddですが、Sloppy*6加算の実装が得られています*7。 参考文献[3] p.5 dd_add・参考文献[4] p.23 Add22・参考文献[5] p.6 SloppyDWPlusDWと同一なので、これで大丈夫そうです。

なお、これらの結果は、このような面倒なことを行わない実装において-fno-honor-nansオプションと-fno-signed-zerosオプションをつけたときと実質的に同じでした。

考察

式テンプレートを使うことで、非標準的なコンパイルオプションに頼らずに、定数伝搬的なことを実現できました。 double-double程度ではライブラリ機能をすべて作るために必要な実装量はそれほど多くないので、それほど役立つようには見えないかもしれません。 しかし、triple-doubleやquad-doubleなどになればadd関数やmul関数の引数の組み合わせは膨大になります。 今回示した方法は、そういう時に役立つかもしれません。

なお、今回はFast系、つまり入力値の大小関係に制限があるものの非分岐な関数を取り扱いました。 一般のadd関数やmul関数はそのような仮定をおけないため、しばしば分岐を含むコードで実装されます。 今回示した方法では、定数伝搬した結果暗黙に成り立つ大小関係を把握することはできません。 しかしそのような場合、実行時にはif文の一方だけが使われることになります。 そのような機械語コードは分岐予測が成功するため、演算ボトルネックであるdouble-double系のプログラムで問題となることは少ないでしょう。

また、今回はFastFMA2222の実装において手動でfmaに変更する最適化を行いました。 しかし、AlderLakeでは浮動小数点数加算が高速であるため、命令数は増えますがfmaを使わずmuladdに分解したほうがレイテンシを削減できることがあります。 fmaの精度が必要ないところについてはターゲットCPUの特性に応じてfmaにしたりmuladdにしたり、ということもできたらおもしろそうです。 変数束縛を超えて勝手にfmaにしてはいけないですが、普通の式テンプレートでは変数に束縛されているのか式の途中なのかを区別することができません。 しかし、C++11で追加された右辺値参照を左辺値参照と区別して使うことで、これができるかもしれません。

最後に、今回の手法でも無駄な命令を取り除けないことがあります。 例えば、TwoProdTwoSumを使わずに加算と乗算だけで実装したTwoFMAを使ってTwoSumTwoProdを導出することを考えます。 TwoSumの方は導出できるのですが、TwoProdの方は無駄命令だらけになってしまいます。 これは、[d, du] = TwoProd( a, b )のようになっている時、d + du == dなのですが、これを使った定数伝搬ができないことが原因です。 これを自動で実現するのは相当困難なように思えます。 ただ、そういったコードは正規化にこだわる実装には出てきますがSloppyな実装にはあまり出てこないので、Fast系を導出する場合はあまり問題にならないのかもしれません。

まとめ

  • doubleで精度が足りない時、double-dobleを使うと手軽で便利
  • double-doubleのライブラリを作るのは面倒(※普通は自分で作らず、kv/kv at master · mskashi/kv · GitHubなどを使うとよいでしょう)
  • 式テンプレートを使うことで、double-double用のfma関数を作るだけで効率の良いaddmulを自動導出できた
  • triple-doubleなどのライブラリを作るときにも役立つと期待

参考文献

  • [1] Boldo, Sylvie and Muller, Jean-Michel, "Exact and Approximated Error of the FMA", IEEE Transactions on Computers 60 157–164 (2011).
  • [2] David Defour, Catherine Daramy, Florent de Dinechin, Matthieu Gallet, Nicolas Gast, Christoph Lauter, Jean-Michel Muller, "cr-libm, a portable, efficient, correctly rounded mathematical library" (2010).
  • [3] 柏木 雅英「double-double 演算、double-double 区間演算に関するまとめ」(2021). http://verifiedby.me/kv/dd/dd.pdf crlibm/crlibm_private.h at master · taschini/crlibm · GitHub
  • [4] Catherine Daramy-Loirat, David Defour, Florent de Dinechin, Matthieu Gallet, Nicolas Gast, Christoph Quirin Lauter, Jean-Michel Muller, "CR-LIBM A library of correctly rounded elementary functions in double-precision" (2006). https://ens-lyon.hal.science/ensl-01529804/file/crlibm.pdf
  • [5] Mioara Maria Joldes, Jean-Michel Muller, Valentina Popescu, "Tight and rigorous error bounds for basic building blocks of double-word arithmetic", ACM Transactions on Mathematical Software 44(2), 1–27 (2017). https://hal.science/hal-01351529v3/document

*1:1.0 + 0x1.p-1000とかを表現できるので、1000 bit以上の精度を持つことが原理的にはあります。
参考:double-double演算が異常に高精度になる!? - kashiの日記

*2:53 bit二つで107 bitになるのは不思議ですが、これは「下位部分」を担当するdoubleの符号ビットが有効利用されるからです。

*3:調べてみると、無名関数の実装にも利用されていたようです。

*4:参考文献[5] p.6 DWTimesDW2では c_{l3} = RN( RN( c_{l1} + x_ly_h) + x_hy_l)とできそうなのにそうしていません。 c_{l2} = RN( RN( x_l \times y_h) + x_hy_l)部分はよくても、 c_{l1}と足すときにもfmaを使うのは問題があるのかもしれません。

*5:正規化されている保証がなくてよい、つまり下位要素の絶対値が上位要素の絶対値の0.5 ULP以下でなくてよいなら、五演算まで減らせるようです(Sloppy乗算)。

*6:Sloppyは「ずさんな」という意味です。double-doubleやquad-doubleなどマルチコンポーネント系の計算手法において、高速化のために精度が犠牲となる実装のことを指します。

*7:この実装はSloppyという名前がついていますが、必ず正規化されたddを返します。しかし桁落ちが発生すると、本来は非零値を返すべきところで零を返すことがあります。Sloppyというのは、この種の誤差が生じうることを意味しているようです。