AVX-512の機能を使ったlogf(x)の実装(その1)

2023年6月に光成滋生さんがAVX-512特有の機能(通常の浮動小数点数演算以外の命令)を使用したlogf(x)のベクトル実装を公開しました(解説記事→AVX-512によるvpermpsを用いたlog(x)の実装)。 具体的には、指数部を取り出すvgetexpps命令、仮数部を取り出すvgetmantps命令、高速なテーブル参照を実現するvpermps命令の三つを使用しています。

これらの機能は非常に強力であり、正しく組み合わせることでlog関数を大幅に高速化することができます。 しかし、このfmathの実装は、AVX-512特有の機能を使わない部分に改善の余地がありそうです。 いろいろな部分を改善していたら、オリジナルとかなり異なるlogf(x)の実装が得られました。

fmathの実装のC++

fmathの実装はアセンブリコードを出力するPythonコードの形で公開されており、読み解くことが困難です(そもそもビルドする方法がわかりませんでした……)。 とりあえず読み解いた感じでは、以下のようなC++コードとほぼ等価な計算を行っているようです。

#include <bit>
#include <cmath>
#include <cstdint>

float vgetexpps( float x ) {
    if( std::abs(x) < 0x1.p-126f ) {
        return vgetexpps( x * 0x1.p+126f ) - 126.f;
    }
    return static_cast<float>(std::bit_cast<std::uint32_t>(x) << 1 >> 24) - 127.f;
}

float vgetmantps( float x ) {
    if( std::abs(x) < 0x1.p-126f ) {
        x *= 0x1.p+126f;
    }
    return std::bit_cast<float>(std::bit_cast<std::uint32_t>(x) & 0x00ffffff | std::bit_cast<std::uint32_t>(1.0f));
}

float vpermps( float x, float (&tbl)[16] ) {
    return tbl[std::bit_cast<std::uint32_t>(x) & 0xf];
}

float vpsrad( float x, std::size_t shamt ) {
    return std::bit_cast<float>(std::bit_cast<std::int32_t>(x) >> shamt);
}

// 以下の部分は、MITSUNARI Shigeoさんがmodified new BSD Licenseで公開
//(https://github.com/herumi/fmath/tree/master)しているfmathライブラリ
// に含まれるfmath_logf_avx512関数の翻案です(不正確かもしれません)。

float tbl1[16] = {
std::bit_cast<float>(0x3f783e10), // 32/33
std::bit_cast<float>(0x3f6a0ea1), // 32/35
std::bit_cast<float>(0x3f5d67c9), // 32/37
std::bit_cast<float>(0x3f520d21), // 32/39
std::bit_cast<float>(0x3f47ce0c), // 32/41
std::bit_cast<float>(0x3f3e82fa), // 32/43
std::bit_cast<float>(0x3f360b61), // 32/45
std::bit_cast<float>(0x3f2e4c41), // 32/47
std::bit_cast<float>(0x3f272f05), // 32/49
std::bit_cast<float>(0x3f20a0a1), // 32/51
std::bit_cast<float>(0x3f1a90e8), // 32/53
std::bit_cast<float>(0x3f14f209), // 32/55
std::bit_cast<float>(0x3f0fb824), // 32/57
std::bit_cast<float>(0x3f0ad8f3), // 32/59
std::bit_cast<float>(0x3f064b8a), // 32/61
std::bit_cast<float>(0x3f020821), // 32/63
};
float tbl2[16] = {
std::bit_cast<float>(0xbcfc14c8), // log(32/33)ではなくlog(tbl1[0])
std::bit_cast<float>(0xbdb78694), // 以下同様
std::bit_cast<float>(0xbe14aa96),
std::bit_cast<float>(0xbe4a92d4),
std::bit_cast<float>(0xbe7dc8c6),
std::bit_cast<float>(0xbe974716),
std::bit_cast<float>(0xbeae8ded),
std::bit_cast<float>(0xbec4d19d),
std::bit_cast<float>(0xbeda27bd),
std::bit_cast<float>(0xbeeea34f),
std::bit_cast<float>(0xbf012a95),
std::bit_cast<float>(0xbf0aa61f),
std::bit_cast<float>(0xbf13caf0),
std::bit_cast<float>(0xbf1c9f07),
std::bit_cast<float>(0xbf2527c4),
std::bit_cast<float>(0xbf2d6a01),
};
float LOG2 = std::bit_cast<float>(0x3f317218);

float fmath_logf( float x ) {
    float expo = vgetexpps( x );
    float mant = vgetmantps( x );
    float idxf = vpsrad( mant, 23 - 4 );

    float b    = vpermps( idxf, tbl1 );
    float c    = std::fma( mant, b, -1.0f );
    float logb = vpermps( idxf, tbl2 );

    float z    = std::fma( expo, LOG2, -logb );

    float xm1  = x - 1.0f;
    float abs  = std::bit_cast<float>(std::bit_cast<std::uint32_t>(xm1) & 0x7fffffff);
    if( abs < 0.02f ) {
        c = xm1;
        z = 0.0f;
    }

    float ret  = std::fma( std::fma( std::fma( std::fma( -0.25008487f, c, 0.3333955701f ), c, -0.49999999f ), c, 1.0f ), c, z );
    return ret;
}

改良

事前計算点の改良

この実装では -\log\left(1+\frac{k+0.5}{16}\right) (0\le k\lt16)を事前計算したテーブルを用意しています。 0.5ずれた点の値を持っているのはidxfを切り捨てで作っているのに合わせているからです。 しかしその代償としてx1.0f付近(logf(x)0.0f付近になって精度の要求が厳しくなる所)の時に精度が足りなくなっていて、それを解決するために分岐が発生しています。 幸い、アルゴリズムを大きく変える必要があるわけではないので、マスク付き演算で実現できますが、追加コストとして五命令かかっています。

この問題は、素直に \log\left(1+\frac{k}{16}\right) (0\le k\lt16)を事前計算したテーブルを持つことで解決可能です。 そのためにはidexfを最近接丸めで作る必要がありますが、これは「すごく大きな数を足すと仮数部が丸められる」というテクニック(これも光成滋生さんの解説で知りました→高速な倍精度指数関数expの実装)を使えばよいです。 具体的には、以下のようにします。

    float mant = vgetmantps( x ); // 1.0f <= mant < 2.0f で出てくる
    float idxf = mant + 0x1.p+19; // 最近接の1/16の倍数に丸める
    if( mant >= 0x1.f8p+0f ) { // 0x1.f8p+0f <= mant < 2.0f の時
            mant *= 0.5f; // mant が b = 1.0 に近くなるように mant をずらす
            expo += 1.0f; // expo もずらす
    }

二命令減らし、精度のための追加コストを三命令にすることができました。 ただし、このif文を外すと完全におかしな値を計算することになるので、「精度を犠牲にしても高速化したいときはこの範囲を消してください」というコードにはなっていません。

テーブル値の精度を上げる

上で行ったexpomantのずらしは、もっと積極的にやってよいです。 境界を0x1.f8p+0fではなく0x1.78p+0fとして、 8\le k \lt 15用の値を \log\left(1+\frac{k}{16}\right)ではなく \log\left(\frac12+\frac{k}{32}\right)にします。 これによりlogbの絶対値が小さくなって精度が上がるので、全体の精度が向上します。

近似多項式の改良

使われている係数が謎です。 見た感じでは |c| \lt 0.019に対する最良近似多項式っぽいですが、範囲をそのようにする理由がわかりません(範囲は |c| \lt \frac1{33} = 0.030303のはずです)。 何か間違いがあるように思えます。

上の最適化を行った場合のcの範囲である -\frac{0.5}{17} \lt c \lt \frac{0.5}{16}を前提にExcelを使って最良近似多項式を求めたところ、四次の係数は-0x1.fe9d24p-3f、三次の係数は0x1.557ee2p-2f、二次の係数は-0x1.00000cp-1f、とすれば良さそうだということがわかりました。 当該範囲での近似誤差は、絶対誤差 6.8\times10^{-10} = 1.46\times2^{-31}未満です。

テーブル値の精度が高くなるように事前計算点を選ぶ

よく考えると、事前計算に用いる点 \sigma_k \left(1+\frac{k}{16}\right)^{-1}浮動小数点数に丸めたものでなくてもよいはずです。 あまりにも離れた値を使うと |c|が大きくなって多項式近似誤差が大きくなってしまいますが、多少のずれは許容されます。 そこで、 \sigma_kとして \left(1+\frac{k}{16}\right)^{-1}にそれなりに近い浮動小数点数のうち \log \sigma_k浮動小数点数に丸めたときに誤差が十分小さくなるものを採用しましょう。

精度的に重要なのは、 \sigma_{15}の修正のようです。 ここで、 \sigma_{15}として適当な値は \frac{32}{31}付近には全くありません。 これは \left.\frac{d}{dx}\log(x)\right|_{x=\frac{32}{31}}=\frac{31}{32}でありこの値が二進浮動小数点数で正確に表せる数であることによるもので、この付近では隣の浮動小数点数におけるlogの返り値も下位ビット(=丸め誤差)がほぼ同じになります。 そのため、 \sigma_{15}0x1.084210p+0から0x1.084550p+0へとかなり大きくずらす必要がありました。

なお、この考え方は元のソースコードの精度を上げるif文がないバージョン(0.5ずれた点でテーブルを作るので命令数が少ない代わりに精度が低い方法)にも適用可能です。 tbl1[0] = 0x1.f081dcp-1ftbl2[0] = -0x1.f76c56p-6f(真の値は-0x1.f76c55ffa3p-6付近なので丸め誤差は0.00071ULP)を使うことで |\log(x)| \lt 2^{-6}での最大絶対誤差を 3.4\times10^{-9} = 1.79 \times2^{-29}未満にすることができます。

改良版のソースコード

#include <bit>
#include <cmath>
#include <cstdint>

float vgetexpps( float x ) {
    if( std::abs(x) < 0x1.p-126f ) {
        return vgetexpps( x * 0x1.p+126f ) - 126.f;
    }
    return static_cast<float>(std::bit_cast<std::uint32_t>(x) << 1 >> 24) - 127.f;
}

float vgetmantps( float x ) {
    if( std::abs(x) < 0x1.p-126f ) {
        x *= 0x1.p+126f;
    }
    return std::bit_cast<float>(std::bit_cast<std::uint32_t>(x) & 0x00ffffff | std::bit_cast<std::uint32_t>(1.0f));
}

float vpermps( float x, float (&tbl)[16] ) {
    return tbl[std::bit_cast<std::uint32_t>(x) & 0xf];
}

float invs_table[16] = {
0x1.000000p+0f, // 16/16
0x1.e1e1e2p-1f, // 16/17
0x1.c71c72p-1f, // 16/18
0x1.af286cp-1f, // 16/19
0x1.99999ap-1f, // 16/20
0x1.861862p-1f, // 16/21
0x1.745d18p-1f, // 16/22
0x1.642c86p-1f, // 16/23
0x1.555556p+0f, // 32/24
0x1.47ae14p+0f, // 32/25
0x1.3b13b2p+0f, // 32/26
0x1.2f684cp+0f, // 32/27
0x1.24924ap+0f, // 32/28
0x1.1a7b96p+0f, // 32/29
0x1.111112p+0f, // 32/30
0x1.084550p+0f, // ~32/31
};
float logs_table[16] = {
+0x0.000000p+0f,
+0x1.f0a30ap-5f,
+0x1.e27074p-4f,
+0x1.5ff306p-3f,
+0x1.c8ff7ap-3f,
+0x1.1675cap-2f,
+0x1.4618bap-2f,
+0x1.739d7ep-2f,
-0x1.269624p-2f,
-0x1.f991c4p-3f,
-0x1.a93ed8p-3f,
-0x1.5bf408p-3f,
-0x1.1178eep-3f,
-0x1.9335e4p-4f,
-0x1.08599ap-4f,
-0x1.047a88p-5f,
};

float my_logf( float x ) {
    float expo = vgetexpps( x );
    float mant = vgetmantps( x );
    float idxf = mant + 0x1.p+19;
    if( mant >= 0x1.78p+0f ) {
        expo += 1.0f;
        mant *= 0.5f;
    }

    float invs = vpermps( idxf, invs_table );
    float t    = std::fma( mant, invs, -1.0f );
    float logs = vpermps( idxf, logs_table );

    float poly = std::fma( std::fma( std::fma( -0x1.fe9d24p-3f, t, 0x1.557ee2p-2f ), t, -0x1.00000cp-1f ), t, 1.0f );
    float ret  = std::fma( poly, t, std::fma( expo, 0x1.62e430p-1f, logs ) );
    return ret;
}

これらの改良により上記実装はfmathの実装よりも二命令減らしつつ最大誤差も小さくすることができました。 対数関数の定義域に属する2139095039個(0x7f7fffff個)の単精度浮動小数点数について誤差を評価してみると、

であり、最大の誤差は-1.925ULP/+2.122ULPでした。

この結果を見ると、ほとんどの場合は真値に最近接の浮動小数点数かその前後の浮動小数点数以外を返しており、それ以外の浮動小数点数を返す(≒1.5ULPを超える誤差が発生する)ことはほとんどないことがわかります。 誤差の原因を詳細に調べて各所の係数を工夫したところ、最大誤差1.5ULP未満を保証する実装を作ることに成功しました。 ただ、その説明は長くなるので、今回の記事はここまでにします。

おまけ:速度重視版のソースコードの改良版

 |\log(x)| \lt 2^{-6}での絶対誤差は 3.4\times10^{-9} = 1.79 \times2^{-29}未満となります。 その外側では-2.289ULP/+1.157ULPです。 まともに最適化していないので、誤差はもう少し小さくできそうです。

#include <bit>
#include <cmath>
#include <cstdint>

float vgetexpps( float x ) {
    if( std::abs(x) < 0x1.p-126f ) {
        return vgetexpps( x * 0x1.p+126f ) - 126.f;
    }
    return static_cast<float>(std::bit_cast<std::uint32_t>(x) << 1 >> 24) - 127.f;
}

float vgetmantps( float x ) {
    if( std::abs(x) < 0x1.p-126f ) {
        x *= 0x1.p+126f;
    }
    return std::bit_cast<float>(std::bit_cast<std::uint32_t>(x) & 0x00ffffff | std::bit_cast<std::uint32_t>(1.0f));
}

float vpermps( float x, float (&tbl)[16] ) {
    return tbl[std::bit_cast<std::uint32_t>(x) & 0xf];
}

float vpsrad( float x, std::size_t shamt ) {
    return std::bit_cast<float>(std::bit_cast<std::int32_t>(x) >> shamt);
}

// 以下の部分は、MITSUNARI Shigeoさんがmodified new BSD Licenseで公開
//(https://github.com/herumi/fmath/tree/master)しているfmathライブラリ
// に含まれるfmath_logf_avx512関数をC++に翻案したものをベースとし、
// テーブル値と多項式の係数を変更することで精度を改善したものです。

float tbl1[16] = {
0x1.f081dcp-1f,
0x1.d42b36p-1f,
0x1.badefep-1f,
0x1.a43a42p-1f,
0x1.8fa128p-1f,
0x1.7c91dcp-1f,
0x1.6bb3b6p-1f,
0x1.5c7e0ap-1f,
0x1.4e9076p-1f,
0x1.40c2dep-1f,
0x1.34bd30p-1f,
0x1.29e50ap-1f,
0x1.1f3f34p-1f,
0x1.16067ap-1f,
0x1.0c996cp-1f,
0x1.048122p-1f,
};
float tbl2[16] = {
-0x1.f76c56p-6f,
-0x1.6e9312p-4f,
-0x1.290ddap-3f,
-0x1.9489aep-3f,
-0x1.fb779ap-3f,
-0x1.2fc65cp-2f,
-0x1.5e3292p-2f,
-0x1.89f0fep-2f,
-0x1.b3b51ap-2f,
-0x1.ded9ccp-2f,
-0x1.02fbeep-1f,
-0x1.154a94p-1f,
-0x1.27ed54p-1f,
-0x1.38a234p-1f,
-0x1.4a4b10p-1f,
-0x1.59f5fap-1f,
};
float LOG2 = std::bit_cast<float>(0x3f317218);

float fast_logf( float x ) {
    float expo = vgetexpps( x );
    float mant = vgetmantps( x );
    float idxf = vpsrad( mant, 23 - 4 );

    float b    = vpermps( idxf, tbl1 );
    float c    = std::fma( mant, b, -1.0f );
    float logb = vpermps( idxf, tbl2 );

    float z    = std::fma( expo, LOG2, -logb );

    float ret  = std::fma( std::fma( std::fma( std::fma( -0x1.0049acp-2f, c, 0x1.557f3ep-2f ), c, -0x1.fffff8p-2f ), c, 1.0f ), c, z );
    return ret;
}