AVX-512の機能を使ったlogf(x)の実装(その2)

前回(AVX-512の機能を使ったlogf(x)の実装(その1) - よーる)の記事で、AVX-512でベクトル実行することを前提としたlogfの高速実装を作りました。 速度を重視しつつもできる限り精度に気を付けた結果、ほとんどの入力に対して誤差を1.5ULP未満とできました。 ただし、対数関数の定義域に属する2139095039個の単精度浮動小数点数のうち、1502個についてだけは誤差が1.5ULPを超えてしまいました(それでも最大誤差は2.2ULP未満です)。 今回は誤差の原因を詳細に解析し、その誤差の要因を避けることのできる係数を発見することで、最大誤差1.5ULPを保証するlogfの高速実装を作ります。

前回の実装の確認

前回の実装の本体を示します。 補助関数まで示すと冗長になるので省略しました(vgetexpps浮動小数点数を仮想的に*1正規化して指数部を取り出す関数、vgetmantps浮動小数点数を仮想的に正規化して仮数部を取り出す関数、vpermpsは16エントリの表引きを行う関数です)。

float my_logf( float x ) {
    float expo = vgetexpps( x );
    float mant = vgetmantps( x );
    float idxf = mant + 0x1.p+19;
    if( mant >= 0x1.78p+0f ) {
        expo += 1.0f;
        mant *= 0.5f;
    }

    float invs = vpermps( idxf, invs_table );
    float t    = std::fma( mant, invs, -1.0f );
    float logs = vpermps( idxf, logs_table );

    float poly = std::fma( std::fma( std::fma( -0x1.fe9d24p-3f, t, 0x1.557ee2p-2f ), t, -0x1.00000cp-1f ), t, 1.0f );
    float ret  = std::fma( poly, t, std::fma( expo, 0x1.62e430p-1f, logs ) );
    return ret;
}

ようするに、xを仮想的に正規化したときの指数部expo e)とxを仮想的に正規化したときの仮数部から決まる16通りの定数invs \sigma_k = s_k^{-1})を使って、 \log x = e\log2 + \log s_k + \log\frac x{2^es_k} のように計算しているということです。

誤差の要因の確認

x1.0fから十分離れている時

 \left|\log\frac x {2^es}\right| |\log x| より十分小さく、 \log\frac x  {2^es} の計算で生じる誤差が最終結果に与える影響は十分小さいです。 よって素直に計算すれば大きな誤差が発生することはありません(実際、1.5ULPを超える誤差が発生する1502ケースはいずれも  \exp(-2^{-6}) \lt x \lt \exp(2^{-5}) です)。

x1.0fに近いとき

expo0.0fとなります。よって、std::fma( expo, 0x1.62e430p-1f, logs )logsになります。 この時、最終結果の誤差は主に以下の四つの誤差の和からなります。

  1. 最後のstd::fma丸め誤差
  2. logs \log sの差
  3. t丸め誤差
  4. poly \log(1+t)/tの差のt

このうち、一つ目は明らかに避けることができません(0.5ULP)。 二つ目は、前回説明したように、 \sigma_k を適切に選ぶことで十分小さくできます。 三つめは、invs1.0fでない時は普通にやると避けることができません(0.5ULP)。

さて、残る「poly \log(1+t)/t の差のt倍」はどうでしょうか。

poly \log(1+t)/t の差」に最も効いてくるのはpolyの計算の最後の丸め誤差(0.5ULP)です。 この0.5ULPの誤差が厄介で、polyが1より小さければ  2^{-25} なのですが、polyが1より大きいと  2^{-24} になってしまいます。  2^{-25} の誤差のt倍は、t仮数部が2に近いときtに対して0.5ULPであり、 2^{-24} の誤差のt倍は、t仮数部が2に近いときtに対して1.0ULPです。 実際にはこれに加えて計算途中の丸め誤差の影響と近似誤差が追加されます。

 k = 0の時

 t = x - 1 になり、tを求める部分では丸め誤差が発生しません。 polyが1より大きくて丸め誤差が大きくなってしまうのは  x \lt 1 の時で、t仮数部が2に近ければlogf(x)の指数部とtの指数部は同じになります。 その時、最終結果に与える影響は、一つ目の誤差が0.5ULP、二つ目の誤差が0ULP、三つ目の誤差が0ULP、四つ目の誤差が1.0ULP以上、で合わせて1.5ULP以上になります。 すべての最悪ケースを同時に引くことはなさそうですが、それでも1.5ULPを保証するのはかなり厳しそうに見えます。

 k \ne 0の時

logf(x)の指数部に比べてtの指数部が十分小さければ、「x1.0fから十分離れている時」と同様の論理で、素直に計算すれば大きな誤差が発生することがありません。

tの指数部がlogf(x)の指数部と同じになる場合を考えます。 これは例えば、 x = 0.984 付近の時に  \sigma_{15} = 32/31 t = x\sigma_{15} -1 = 1.00748\times2^{-6} \log x = -1.03228\times2^{-6} になるような場合のことです。 すると、最終結果に与える影響は、一つ目の誤差が0.5ULP、二つ目の誤差がほぼ0ULP、三つ目の誤差が0.5ULP、四つ目の誤差が0.5ULP以上、で合わせて1.5ULP以上になります。 tの指数部が大きいということはつまり多項式近似の端の方なので、近似誤差を含めて1.5ULP未満に抑えることは難しそうです。 さらに言えば、 x = 1.0315 付近の時などはもっとひどくて、一つ目の誤差が0.5ULP、二つ目の誤差がほぼ0ULP、三つ目の誤差が0.5ULP、四つ目の誤差が1.0ULP以上、で合わせて2.0ULP以上になって、どうにもなりません(実際、2.0ULPを超える誤差が発生する5ケースは全てこれです)。

このような問題が起こるのは、 \sigma_k が切り替わる境界がlogfの指数部やtの指数部が切り替わる境界とずれているからです。 それをまとめたのが以下の表です。

本来あるべき境界 実際の境界 問題点
 \sigma_{15} \sigma_0 の境界  (1+2^{-6})s_{15}=0.983887  \frac{63}{64}=0.984375  k = 15 の負担が大
 \sigma_0 \sigma_1 の境界  \exp(2^{-5})=1.031743  \frac{33}{32}=1.03125  k = 1 の負担が大

このうち、 s_{15}がかかわるほうの問題は s_{15}を適切にとることで解消できます。

戦略

以下の二つの戦略により、上記問題を解決します。

  • idxfを求める部分で浮動小数点数加算ではなく浮動小数点数積和演算を用いることで、命令数を増やさずに境界をずらす
  • 多項式近似の係数を工夫して、 \log(1+t)/t \gt 1 となる範囲( t \lt 0)では極力近似誤差が0に近くなるようにする

さらに、上記戦略をうまく働かせるため、 k ごとに  t としてあり得る範囲を調整します。 これは  \sigma_k \left(1+\frac{k}{16}\right)^{-1} から大胆にずらすことで実現します。

終結果に与える最大誤差が以下のようになることを目指して最適化します。

誤差要因  k = 0 の場合  k = 15 または  k = 1 の場合
最後のstd::fma丸め誤差 0.5ULP 0.5ULP
logs \log sの差 0ULP ほぼ0ULP
t丸め誤差 0ULP 0.5ULP
poly \log(1+t)/tの差のt 1.0ULP 0.5ULP

idxfを求める部分の変更

 \sigma_0 \sigma_1 の境界が  \exp(2^{-5})=1.031743 より大きくなり、かつ  \sigma_{15} \sigma_0 の境界が  \exp(-2^{-6})=0.984496 以下となれば良いです。

実際は近似多項式の係数と同時に最適化するのですが、結論から言うと以下のように設定すればうまくいきました。

    float idxf = fma( mant, 0x1.fd9c88p-1f, 0x1.p+19f );

このようにすることで、 k = 1 かつ  t \lt 0 の場合で誤差が大きくなる問題を防ぎます。 代わりに  k = 0 の場合に取り扱わなければいけないtの範囲が  -0.016130 \lt t \le 0.03125 から  -0.011012 \lt t \lt 0.036084 になります。 負の側は範囲が狭くなり、正の側は範囲が広くなっています。 負の側は精度の要求が厳しく、正の側は精度の要求が緩いので、それと合致しています。 単に  k = 1 かつ  t \lt 0 の場合の回避だけでない効果をここに求めているようです。

 \sigma_k の選択

以下を満たしていれば意外と自由に取れます。

  •  \sigma_{15}
    • tの上限が  2^{-7} 以下になるようにする
      • 境界変更により  -2^{-6} \lt \log(x) \lt -2^{-7} となる付近を取り扱わないといけないので、  2^{-7} \lt t だとpoly \log(1+t)/t の差のt倍が最終結果に対して0.5ULP以上になってしまう
      • 多項式近似の区間の幅を制約するので、tの上限はなるべく  2^{-7} に近いほうが良い
  •  \sigma_1
    • tの上限が  k = 0 の場合のtの上限以下になるようにする
    • tの下限が  k = 15 の場合のtの下限以下になるようにする

多項式近似の工夫

上記  \sigma_k の選択の下、以下を守れば全体が1.5ULP保証になります。 実際には、ほんの少し守れていなくても、誤差要因全てがそろうことはないので、結果的に1.5ULP保証できる実装が手に入ります。

  • 多項式近似の誤差が  t=-2^{-N}(N=7,8,9,\dots) でほぼ0になるようにする
    • poly \log(1+t)/t の差のt倍」はt仮数部が2に近いときtのULPに対して大きくなるので、 t が二の冪乗の時が重要
    •  k = 0 において、t仮数部が2に近くて t \lt 0ならばlogf(x)のULPはtのULPと同じ
    • poly丸め誤差だけで最終結果に与える影響が1.0ULPあるので、多項式近似誤差が少しでもあると1.0ULPを超えてしまう
  • poly \log(1+t)/t の差のt倍について、
    • 区間の左端  t = L において  2^{-29} 未満になるようにする
      •  k = 0 では使われず、 k = 15 k = 1 で使われる
      • その時  2^{-5} \lt |\log x| \lt 2^{-4}なので、logf(x)に対しての0.5ULPは 2^{-29}
    • 区間の右端  t = R において  2^{-28} 未満になるようにする
      •  k = 0 で使われ、その時  2^{-5} \lt \log x \lt 2^{-4}なので、logf(x)に対しての1.0ULPは 2^{-28}
      •  k = 1 で使われ、その時  2^{-4} \lt \log x \lt 2^{-3}なので、logf(x)に対しての0.5ULPは 2^{-28}
        • ただし、 k = 1 の場合にはtの指数部がlogf(x)の指数部よりも一つ小さくてt丸め誤差が最終結果に与える影響が0.25ULPしかないので実際には問題とならない
    •  t \in [\exp(2^{-N-1})-1, \exp(2^{-N})-1)(N=5,6,7,\dots) において  2^{-N-24} 未満になるようにする
      •  k = 0 において重要で、その区間において  2^{-N-1} \le \log x \lt 2^{-N}なので、logf(x)に対しての1.0ULPは 2^{-N-24}

これを満たす近似多項式は、以下のようになります(限界までは最適化していません)。

    const float C4 = -0x1.fb4ef8p-3f;
    const float C3 =  0x1.556fc4p-2f;
    const float C2 = -0x1.ffffe2p-2f;

    float poly = std::fma( std::fma( std::fma( C4, t, C3 ), t, C2 ), t, 1.0f );

この時、tの値ごとの「t丸め誤差」と「poly \log(1+t)/t の差のt倍」が最終結果に与えるULPでみた影響の最大値は、図1のようになります。

図1: tの値ごとの最終丸め誤差以外の誤差が最終結果に与えるULPでみた影響の最大値のグラフ

なんとか±1.0ULPに収まっています。

最終的な実装のソースコード

// -------------------------------------------------------
//  Copyright 2023 lpha-z
//  https://lpha-z.hatenablog.com/entry/2023/09/03/231500
//  Distributed under the MIT license
//  https://opensource.org/license/mit/
// -------------------------------------------------------

#include <bit>
#include <cmath>
#include <cstdint>

float vgetexpps( float x ) {
    if( std::abs(x) < 0x1.p-126f ) {
        return vgetexpps( x * 0x1.p+126f ) - 126.f;
    }
    return static_cast<float>(std::bit_cast<std::uint32_t>(x) << 1 >> 24) - 127.f;
}

float vgetmantps( float x ) {
    if( std::abs(x) < 0x1.p-126f ) {
        x *= 0x1.p+126f;
    }
    return std::bit_cast<float>(std::bit_cast<std::uint32_t>(x) & 0x00ffffff | std::bit_cast<std::uint32_t>(1.0f));
}

float vpermps( float x, float (&tbl)[16] ) {
    return tbl[std::bit_cast<std::uint32_t>(x) & 0xf];
}

float invs_table[16] = {
0x1.0000000p+0f,
0x1.e286920p-1f,
0x1.c726fe0p-1f,
0x1.af35980p-1f,
0x1.99a95e0p-1f,
0x1.861a9e0p-1f,
0x1.746c640p-1f,
0x1.6435820p-1f,
0x1.5564f40p+0f,
0x1.47a8960p+0f,
0x1.3b1c5e0p+0f,
0x1.2f640a0p+0f,
0x1.24958c0p+0f,
0x1.1a813e0p+0f,
0x1.11180c0p+0f,
0x1.04d9b40p+0f,
};
float logs_table[16] = {
+0x0.000000p+0f,
+0x1.e5b538p-5f,
+0x1.e2118ap-4f,
+0x1.5fb476p-3f,
+0x1.c8b0a8p-3f,
+0x1.166fecp-2f,
+0x1.45eeaap-2f,
+0x1.7383aap-2f,
-0x1.26c4fcp-2f,
-0x1.f96f70p-3f,
-0x1.a97736p-3f,
-0x1.5bd74ap-3f,
-0x1.118fbcp-3f,
-0x1.9387e8p-4f,
-0x1.08c23ep-4f,
-0x1.338588p-6f,
};

float my_logf( float x ) {
    float expo = vgetexpps( x );
    float mant = vgetmantps( x );
    float idxf = fma( mant, 0x1.fd9c88p-1f, 0x1.p+19f );
    if( mant >= 0x1.79c328p+0f ) {
        expo += 1.0f;
        mant *= 0.5f;
    }

    float invs = vpermps( idxf, invs_table );
    float t    = std::fma( mant, invs, -1.0f );
    float logs = vpermps( idxf, logs_table );

    const float C4 = -0x1.fb1370p-3f;
    const float C3 =  0x1.556f14p-2f;
    const float C2 = -0x1.ffffe2p-2f;
    float poly = std::fma( std::fma( std::fma( C4, t, C3 ), t, C2 ), t, 1.0f );
    float ret  = std::fma( poly, t, std::fma( expo, 0x1.62e430p-1f, logs ) );

    return ret;
}

これらの改良により上記実装は先週示した実装と比べて命令数を増やさずに最大誤差を小さくすることができました。 対数関数の定義域に属する2139095039個(0x7f7fffff個)の単精度浮動小数点数について誤差を評価してみると、

となりました。  2 \le k \le 14 に対する  \sigma_k は適当にとったので、tの絶対値が小さくなるよう範囲の中心にとれば平均誤差はもう少し減らせるかもしれません。 ただ、0.5ULPを超える誤差が頻繁に発生する原因はstd::fma( expo, 0x1.62e430p-1f, logs )の丸めなので、係数を工夫しても0.5ULPを超える誤差が発生する頻度を下げることは難しそうです。 std::fma( expo, 0x1.62e430p-1f, std::fma( poly, t, logs ) )の順番で計算すれば0.5ULPを超える誤差が発生する頻度を3.3%まで下げることができますが、レイテンシが一命令分伸びます。

また、最大の誤差は-1.45944ULP/+1.47702ULPでした。 ULPで見た最大誤差を1.5ULPより有意に小さくすることは、命令数を増やさずには不可能だと思います。 アルゴリズムを大幅に見直せば回避できる可能性はありますが、おそらく命令数が増えます。

まとめ

logf(x)関数に含まれる誤差の原因を詳細に調べました。 idxfの計算に積和演算を使うことでテーブルの担当領域をずらし、さらに注意深く近似多項式を設計することにより、最大誤差1.5ULP未満を保証する実装を作ることに成功しました。

おまけ:没になった戦略

 \sigma_1 として下位ビットに0が並ぶような浮動小数点数を選択すれば、tを計算するときの丸め誤差を0にすることができます。 しかし、 \sigma_1 として選べる浮動小数点数の選択肢が少なく、 \log s_1浮動小数点数に丸めるときの誤差を小さくすることができません。  \log s_1 の指数部はtの指数部より一つ大きいので、tを計算するときの丸め誤差0.5ULPよりも最終結果に与える影響を小さくしたければ、 \log s_1丸め誤差は0.25ULP以内に収めなければなりません。 これを満たすことができる \sigma_10x1.efp-1くらいしかなく、理想的な値0x1.e2p-1から離れすぎています。 よって、この方法は採用しませんでした。

*1:通常の浮動小数点数には指数部の制限があり非正規化数がありますが、その制限を無視して取り扱うということを意味します。