高速な完全精度 expf 関数の作り方

結構前に、完全精度(すべての入力に対して最近接丸めを行う) expf 関数の作り方を明らかにしました(完全精度 expf 関数の作り方 - よーる)。

今回は、その高速化を行います。積和演算がそれなりに高速なCPUを前提とします。

高速化の技法

多くの高速化の技法は、高速な倍精度指数関数expの実装が参考になります。ただし、この時代のIntel CPUには積和演算命令がなかったため、積和演算命令を考慮していない高速化技法になっています*1

テイラー展開テーブル引きを使う

 \exp(x) = \exp(s+t) = \exp(s)\exp(t)と分解します。ここで、 \exp(s)は簡単に求められるように sを選びます。かつ、 tはその絶対値がなるべく小さくなるように選びます。  \exp(t)はテイラー近似により求める予定なので、 |t|が小さい方が速く(少ない項数で)収束してうれしいという目論見です。

 s \log2の倍数であれば、 \exp(s)は二のべき乗になり、非常に簡単に求められます。 しかし、 |t|は最大で0.35程度となるため、テイラー近似の項数が多く必要になります。

 s = \frac{\log2}2の時の \exp(s)の値(つまり \sqrt{2})を持っておけば、 s \frac{\log2}2の倍数の場合に対応できるため、 |t| \lt \frac{\log2}{4}にできます。

一般に、 s = \frac{k\log2}N (0\le k \lt N)の時の \exp(s)の値(つまり 2^{\frac k N})を持っておけば、 s \frac{\log2}Nの倍数の場合に対応できて、 |t| \lt  \frac{\log2}{2N}にできます。

 Nを二のべき乗にしておくと、 kを計算するための Nでの剰余演算が、上位ビットをゼロにするだけとなるため、計算が楽になります。

 \exp(t)ではなく \exp(t)-1を求める

先ほど \exp(t)をテイラー近似で求めると言いましたが、実際には \exp(t)-1を求めます。そして、 \exp(x) = \exp(s+t) = \exp(s)\exp(t)ではなく、 \exp(x) = \exp(s+t) = \exp(s) + \exp(s)\left(\exp(t)-1\right)のように求めます。

 \exp(t)-1を求めるのは遠回りのように思われるかもしれません。しかし、そのテイラー展開 x + \frac{x^2}2 + \frac{x^3}6 + \cdotsのようになっているため、むしろ \exp(t)より計算しやすい値です。 そもそも、 \exp(t)テイラー展開 1 + x + \frac{x^2}2 + \frac{x^3}6 + \cdotsのようになっていることを考えれば、これは \exp(t)を求める際の途中結果です。 1.0を加算して \exp(t)を求めた後 \exp(s)との積を計算する手順を一回の積和演算に統合できるので高速化につながります。

しかも、 \exp(t)-1だと0.000012345678のように精度高く保持できたところ、 \exp(t)としてしまうと1.0000123のように丸められて精度が低下してしまうのを防ぐ効果もあります。

大きな数を足して丸める

倍精度浮動小数点数仮数部は52bit(ケチ表現なので実際は53bit)なので、252~253の範囲の数になるようにすると小数点以下がまるめられます。

 x \frac N {\log2}の絶対値はそれほど大きくないため、 R = 1.5\times 2^{52}を足すことにより常にその範囲の数にできます。

つまり、 Rを足した後 Rを引けば、小数点以下を丸めることができることになります。

また、 R + x \frac N {\log2}を計算すると R + Nsに丸められますが、その倍精度浮動小数点数表現の下51bitは Nsの二の補数表現と等価になっています。CPUによりますが、これは高速化に役立てられることが多いです。

 t = x - s\log2を精度良く求める

 \exp(t)の絶対誤差は tの絶対誤差にほぼ等しくなります。そのため tをそれなりに正確に求める必要がありますが、 x s\log2は同じくらいの値であるため、数式通り計算してしまうと桁落ちが発生して精度が不足します。

この式は積和演算の形になっているため積和演算命令を利用すれば精度不足が解消できそうにも思えます。確かに精度の向上は得られますが、今度は \log2の精度不足が壁となります。 \log2の絶対誤差の s倍が tの絶対誤差となるため、 \log2は倍精度浮動小数点数の53bit精度ですら不足しているのです*2

 \log2を倍倍精度表現(二つの倍精度浮動小数点数ln2hln2lの和)として表せば精度の問題は解決します。しかし、今度は積和演算が使えない形になるため、中間結果に丸め誤差が発生します。どのくらいの丸め誤差が発生したかの評価は可能です(実際、倍倍精度計算では常に行っています)が、高コストです。

kの絶対値は218を超えないことを利用すると、中間結果に丸め誤差が発生しないようにできます。k * ln2hの計算に丸め誤差は発生しなくするためには、ln2hの下位18bitを0にしておきます。xk * ln2hの値は近いはずなので丸め誤差は発生しません。というよりむしろ桁落ちするのですが、ここまでの計算で一度も丸め誤差を発生させていないため、無害な桁落ちです。

一方、ln2hの下位18bitを0にする必要から \log2の精度が89bitに低下しています。しかし、実際には倍倍精度表現の107bit精度が必要になる局面は存在せず、問題になりません。

このようにすることで tを最後の一桁まで正確に求めることができます。

実際のコード

#include <cmath>
#include <cstdint>
#include <cstring>

template<class To, class From>
To bit_cast( const From& from ) noexcept {
    To to;
    static_assert( sizeof to == sizeof from );
    std::memcpy( &to, &from, sizeof to );
    return to;
}

namespace {
    double expm1_taylor3( double t1 ) noexcept {
        constexpr double C2 = 1.0 / 2.0;
        constexpr double C3 = 1.0 / 6.0;
        const double s1 = std::fma( C3, t1, C2 );
        const double t2 = t1 * t1;
        return std::fma( s1, t2, t1 );
    }

    double exp_table( uint64_t s ) noexcept {
        constexpr double b1table[32] {
            0x1.0000000000000p+0,
            0x1.059b0d3158574p+0,
            0x1.0b5586cf9890fp+0,
            0x1.11301d0125b51p+0,
            0x1.172b83c7d517bp+0,
            0x1.1d4873168b9aap+0,
            0x1.2387a6e756238p+0,
            0x1.29e9df51fdee1p+0,
            0x1.306fe0a31b715p+0,
            0x1.371a7373aa9cbp+0,
            0x1.3dea64c123422p+0,
            0x1.44e086061892dp+0,
            0x1.4bfdad5362a27p+0,
            0x1.5342b569d4f82p+0,
            0x1.5ab07dd485429p+0,
            0x1.6247eb03a5585p+0,
            0x1.6a09e667f3bcdp+0,
            0x1.71f75e8ec5f74p+0,
            0x1.7a11473eb0187p+0,
            0x1.82589994cce13p+0,
            0x1.8ace5422aa0dbp+0,
            0x1.93737b0cdc5e5p+0,
            0x1.9c49182a3f090p+0,
            0x1.a5503b23e255dp+0,
            0x1.ae89f995ad3adp+0,
            0x1.b7f76f2fb5e47p+0,
            0x1.c199bdd85529cp+0,
            0x1.cb720dcef9069p+0,
            0x1.d5818dcfba487p+0,
            0x1.dfc97337b9b5fp+0,
            0x1.ea4afa2a490dap+0,
            0x1.f50765b6e4540p+0,
        };

        constexpr double b2table[32] {
            0x1.0000000000000p+0,
            0x1.002c605e2e8cfp+0,
            0x1.0058c86da1c0ap+0,
            0x1.0085382faef83p+0,
            0x1.00b1afa5abcbfp+0,
            0x1.00de2ed0ee0f5p+0,
            0x1.010ab5b2cbd11p+0,
            0x1.0137444c9b5b5p+0,
            0x1.0163da9fb3335p+0,
            0x1.019078ad6a19fp+0,
            0x1.01bd1e77170b4p+0,
            0x1.01e9cbfe113efp+0,
            0x1.02168143b0281p+0,
            0x1.02433e494b755p+0,
            0x1.027003103b10ep+0,
            0x1.029ccf99d720ap+0,
            0x1.02c9a3e778061p+0,
            0x1.02f67ffa765e6p+0,
            0x1.032363d42b027p+0,
            0x1.03504f75ef071p+0,
            0x1.037d42e11bbccp+0,
            0x1.03aa3e170aafep+0,
            0x1.03d7411915a8ap+0,
            0x1.04044be896ab6p+0,
            0x1.04315e86e7f85p+0,
            0x1.045e78f5640b9p+0,
            0x1.048b9b35659d8p+0,
            0x1.04b8c54847a28p+0,
            0x1.04e5f72f654b1p+0,
            0x1.051330ec1a03fp+0,
            0x1.0540727fc1762p+0,
            0x1.056dbbebb786bp+0,
        };

        const double b1 = b1table[s>>5&31];
        const double b2 = b2table[s&31];
        const uint64_t exponent = (s >> 10) << 52;
        return bit_cast<double>( bit_cast<uint64_t>( b1 * b2 ) + exponent );
    }
}

float exact_expf( float x ) noexcept {
    if( x < -104.0f ) { return 0.0f; }
    if( x > 0x1.62e42ep+6f ) { return HUGE_VALF; }

    constexpr double R    =  0x3.p+51f;
    constexpr double iln2 =  0x1.71547652b82fep+10;
    constexpr double ln2h =  0x1.62e42fefc0000p-11;
    constexpr double ln2l = -0x1.c610ca86c3899p-47;

    const double k_R     = std::fma( static_cast<double>(x), iln2, R );
    const double k       = k_R - R;
    const double t       = std::fma( k, -ln2l, std::fma( k, ln2h, static_cast<double>(x) ) );
    const double exp_s   = exp_table( bit_cast<uint64_t>(k_R) );
    const double expm1_t = expm1_taylor3( t );
    const double exp_x   = std::fma( exp_s, expm1_t, exp_s );
    return static_cast<float>( exp_x );
}

精度

上記実装は、以前作った完全精度expf関数と出力が一致しました。以前作った完全精度expf関数は精度保証演算ライブラリkvによって完全精度であることを確認してあります。よってこの実装も完全精度です。

inf付近での丸めに関する挙動は、以前(IEEE754浮動小数点数の丸めに関するメモ - よーる)調べたように、最近接丸めを行うためにはinfがあたかも0x1.p+128であるかのように取り扱えばよいです。 以前作った完全精度expf関数の記事では、inf付近の挙動をどうすべきかわからないと書きましたが、正しい動作をしていたということになります。

さらに高速化できそうな点

さらなる改善の余地は残されているような気がしますが、これ以上はチューニングの域でしょう。

  • tを求める部分
    • 積和演算を使った高精度演算を利用する
    • ln2の精度が53bitで足りるようになるような分割方法にする
  • exponentを求める部分
    • Rを変更すれば下駄を履かせられるのでexp_table(s)の計算が浮動小数レジスタ上で計算できるようになる

(7/14頃に書いて放置していた記事の加筆)

*1:実際、積和演算命令が使えるCPUであれば、fmath::expはexp関数より低速であることが多いです。これはfmath::expが積和演算命令を含まないアセンブリで書かれていることが原因です。高級言語で書かれていれば、コンパイラが積和演算命令を使うという最適化をかけられるのですが、アセンブリで書かれていると時代に取り残されてしまうのですね……。

*2:sが大きな値になる浮動小数点数は少ないため、運が良ければ回避できるかと思いましたが、ぎりぎりダメでした。