もっと高速な完全精度 expf 関数の作り方

何年か前に、高速な完全精度(すべての入力に対して最近接丸めを行う) expf 関数の作り方を明らかにしました(高速な完全精度 expf 関数の作り方 - よーる)。

その中で、 t = x - s\log2を求める際に使う \log2は、倍精度浮動小数点数の53bit精度ですら精度が不足しているので、二つの倍精度浮動小数点数ln2hln2lの和で表す必要がある、と書きました。

これ自体は正しいのですが、ここの誤差に由来してexpf(x)が正しく求められないxは二つしかありません。 よって、他の部分のアルゴリズム部分にわずかな(数学的な正当化ができないような姑息な)変更を加えることで、全体として返す結果が正しくなるようにできる可能性があります。

実際、いくつか実験してみると、tの誤差を補償することで完全精度とできる実装を手に入れることができました。

以下、試したことを書いてみます。

多項式近似をいじってみる

この方針はうまくいきませんでした。 というのも、多項式近似はあらゆる入力で使われるため、tの誤差を補償しようとすると、他の入力の時に正しくない値を返してしまうことにつながるからです。

テーブルをいじってみる

この方針が正しかったです。

tの誤差に由来してexpf(x)が正しく求められないxは、x1 = -0x1.d2259ap+3fx2 = 0x1.112856p+6fだけです。 x1ではb1table[31]b2table[16]を、x2ではb1table[16]b2table[21]を使います。 これらのテーブルの値を、expf(x)が目的の値になるようにわずかにずらすことを考えます。 ずらす量 \varepsilonを二つのテーブルに分散させれば、

  • ほとんどの入力に対しては、たかだか \varepsilon/2しかずれない
  • 運悪く同じテーブルの組み合わせを使う入力に対しては \varepsilonずれてしまうけれど、そういう入力は全体の1/1024しかないので、運が悪くなければ完全精度自体は達成可能

とできます。 なんだかブルームフィルタみたいです。

実際にやってみると、x3 = -0x1.e1dbe2p-8fで問題が発生します。 この入力は元々運よくぎりぎり正しい丸め方向になっていた入力で、b1table[31]b2table[21]を使っています。

x1に対しては5ULP増やす必要があり、x2に対しては8ULP減らす必要があります。 つまり、x1x2では、動かすべき方向が逆で、x2の方がたくさんうごかす必要があります。 よって、ずらす量を二つのテーブルに均等に分散させると、x3の計算結果がずれてしまい、今回の場合は丸め境界をまたいでしまうようです。 そこで、x2の調整のためには、b2table[21]はあまりずらさず、かわりにb1table[16]をたくさんずらす、とする必要があるようです。

実際のコード

#include <cmath>
#include <cstdint>
#include <cstring>
#include <bit>

namespace {
    double expm1_taylor3( double t1 ) noexcept {
        constexpr double C2 = 1.0 / 2.0;
        constexpr double C3 = 1.0 / 6.0;
        const double s1 = std::fma( C3, t1, C2 );
        const double t2 = t1 * t1;
        return std::fma( s1, t2, t1 );
    }

    double exp_table( uint64_t s ) noexcept {
        constexpr double b1table[32] {
            0x1.0000000000000p+0,
            0x1.059b0d3158574p+0,
            0x1.0b5586cf9890fp+0,
            0x1.11301d0125b51p+0,
            0x1.172b83c7d517bp+0,
            0x1.1d4873168b9aap+0,
            0x1.2387a6e756238p+0,
            0x1.29e9df51fdee1p+0,
            0x1.306fe0a31b715p+0,
            0x1.371a7373aa9cbp+0,
            0x1.3dea64c123422p+0,
            0x1.44e086061892dp+0,
            0x1.4bfdad5362a27p+0,
            0x1.5342b569d4f82p+0,
            0x1.5ab07dd485429p+0,
            0x1.6247eb03a5585p+0,
            0x1.6a09e667f3bc7p+0, // pow(2, 16./32) = 0x1.6a09e667f3bcdp+0 から6ULPずらした
            0x1.71f75e8ec5f74p+0,
            0x1.7a11473eb0187p+0,
            0x1.82589994cce13p+0,
            0x1.8ace5422aa0dbp+0,
            0x1.93737b0cdc5e5p+0,
            0x1.9c49182a3f090p+0,
            0x1.a5503b23e255dp+0,
            0x1.ae89f995ad3adp+0,
            0x1.b7f76f2fb5e47p+0,
            0x1.c199bdd85529cp+0,
            0x1.cb720dcef9069p+0,
            0x1.d5818dcfba487p+0,
            0x1.dfc97337b9b5fp+0,
            0x1.ea4afa2a490dap+0,
            0x1.f50765b6e4542p+0, // pow(2, 31./32) = 0x1.f50765b6e4540p+0 から2ULPずらした
        };

        constexpr double b2table[32] {
            0x1.0000000000000p+0,
            0x1.002c605e2e8cfp+0,
            0x1.0058c86da1c0ap+0,
            0x1.0085382faef83p+0,
            0x1.00b1afa5abcbfp+0,
            0x1.00de2ed0ee0f5p+0,
            0x1.010ab5b2cbd11p+0,
            0x1.0137444c9b5b5p+0,
            0x1.0163da9fb3335p+0,
            0x1.019078ad6a19fp+0,
            0x1.01bd1e77170b4p+0,
            0x1.01e9cbfe113efp+0,
            0x1.02168143b0281p+0,
            0x1.02433e494b755p+0,
            0x1.027003103b10ep+0,
            0x1.029ccf99d720ap+0,
            0x1.02c9a3e778063p+0, // pow(2, 16./1024) = 0x1.02c9a3e778061p+0 から2ULPずらした
            0x1.02f67ffa765e6p+0,
            0x1.032363d42b027p+0,
            0x1.03504f75ef071p+0,
            0x1.037d42e11bbccp+0,
            0x1.03aa3e170aafcp+0, // pow(2, 21./1024) = 0x1.03aa3e170aafep+0 から2ULPずらした
            0x1.03d7411915a8ap+0,
            0x1.04044be896ab6p+0,
            0x1.04315e86e7f85p+0,
            0x1.045e78f5640b9p+0,
            0x1.048b9b35659d8p+0,
            0x1.04b8c54847a28p+0,
            0x1.04e5f72f654b1p+0,
            0x1.051330ec1a03fp+0,
            0x1.0540727fc1762p+0,
            0x1.056dbbebb786bp+0,
        };

        const double b1 = b1table[s>>5&31];
        const double b2 = b2table[s&31];
        const uint64_t exponent = (s >> 10) << 52;
        return std::bit_cast<double>( std::bit_cast<uint64_t>( b1 * b2 ) + exponent );
    }
}

float exact_expf( float x ) noexcept {
    if( x < -104.0f ) { return 0.0f; }
    if( x > 0x1.62e42ep+6f ) { return HUGE_VALF; }

    constexpr double R    =  0x3.p+51f;
    constexpr double iln2 =  0x1.71547652b82fep+10;
    constexpr double ln2  =  0x1.62e42fefa39efp-11;

    const double k_R     = std::fma( static_cast<double>(x), iln2, R );
    const double k       = k_R - R;
    const double t       = std::fma( k, -ln2, static_cast<double>(x) );
    const double exp_s   = exp_table( std::bit_cast<uint64_t>(k_R) );
    const double expm1_t = expm1_taylor3( t );
    const double exp_x   = std::fma( exp_s, expm1_t, exp_s );
    return static_cast<float>( exp_x );
}