中点の正しい計算方法

二つの浮動小数点数 a, bについて、中点の数学的な値は \frac{a+b}2です。 以下の記事に触発されたので、この中点を浮動小数点数に正しく丸めた値を得る方法について考えます。

中点はどうやって計算するべきか - kashiの日記

当該記事によれば、この問題は意外と難しく、分岐なしに浮動小数点数演算だけで正しく求める方法はまだ知られていないようです。

何が難しいのか

足したらオーバーフローする

まず、もっとも単純に思いつくのは(a+b)*0.5と計算する方法でしょう。 しかしこれは、a+bの時点でオーバーフローする可能性があるので、中点を正しく求めることができません。

引いてもオーバーフローする

高精度に計算する方法として、a + (b-a)*0.5という方法が紹介されることもあります。 しかし、今度はb-aがオーバーフローする可能性があります。 それは、baの符号が異なり、それぞれ絶対値がすごく大きい場合です。 まぁそんな二数の中点を正確に求めたいということはなさそうですが、正しいアルゴリズムではないという点が気になります。

アンダーフローした際の丸め誤差

では、a*0.5 + b*0.5a + (b*0.5 - a*0.5)とするのはどうでしょうか。 これらはオーバーフローの可能性はありませんが、今度は逆にアンダーフローの可能性が出てきます。 通常、二進浮動小数点数の0.5倍は丸め誤差なしに正確に表せますが、結果がアンダーフローする場合は丸め誤差が生じる可能性があります。 最終結果に与える影響が同じ方向の丸め誤差が二つ以上の演算で発生した場合、正しくない答えが求まることがあります。

正しく求める方法

以下では、実装を検証しやすいようにfloatを使います(double用のmidpoint関数は入力のパターンが 2^{128}通りあって検証が困難です)。 floatの特徴はほぼ使っていないので、簡単にdoubleに置き換えることができます。

作戦

std::fma( a, 0.5f, b*0.5f )がほとんどの場合に正しい値を返すことに注目します。 これが正しくない丸めを行う必要条件は、b*0.5fの計算で丸め誤差が発生することです。 この丸め誤差に由来する真の中点とstd::fma( a, 0.5f, b*0.5f )の差の二倍hを算出することを考えます。

使用テクニック

double-doubleのように、実数を二つの浮動小数点数 x, yの組で表現することを考えます。 しかし、double-doubleのように x + yを表現するのではなく、 x + 0.5yを表現することを考えます。

この表現を使うと、任意の浮動小数点数 aについて、 0.5aが表現できます。  x=0, y=aとすればよいのでこれ自体は自明です。 重要なのは、以下のTwoProduct類似の手順で正規化することができることです。

std::pair<float, float> Half( float a ) {
    float hi = a * 0.5f;
    float lo = std::fma( hi, -2.0f, a );
    return { hi, lo };
}

この時、lo0.0fFLT_TRUE_MIN-FLT_TRUE_MINになります(丸めモードが何であってもそうなります)。

このように0.5倍を無誤差で扱う方法を導入します。

概略

まず、先のHalf関数を用いて、次の e, fを求めます。

auto [e, f] = Half(b)

つづいて、仮の中点値 m'を求めます。

float m_ = std::fma(a, 0.5f, e)

 m'と真の中点の差の二倍 hを計算します。

float h = f ? std::fma( m_, -2.0f, a ) + b : 0.0f;

最後に、 m'にそれを足し合わせ、最終結果を求めます。

float ret = std::fma( h, 0.5f, m_ );

もちろん、これでは分岐が発生しているのでダメです。 しかし、f0.0fFLT_TRUE_MIN-FLT_TRUE_MINしかありえないことを思い出します。 これを利用すると、以下のようにして分岐を消すことができます。

std::fesetround(FE_TOWARDZERO);
float half_or_zero = std::fma( f, f, -FLT_MIN/FLT_TRUE_MIN ) + FLT_MIN/FLT_TRUE_MIN;
std::fesetround(FE_TONEAREST);

float m_ = std::fma( a, 0.5f, e );
float h = std::fma( m_, -2.0f, a ) + b;

float ret = std::fma( h, half_or_zero, m_ );

これでも微妙にダメな点がいくつかあります。 一つ目は、b±0x1.fffffep-126fの時にうまくいかないことです。 この問題は、eを求めるときに零への丸めを使えば回避できます。 二つ目は、hがオーバーフローして無限大になると0.0fをかけてもNaNになってしまうことです。 この問題も、hを求めるときに零への丸めを使えば回避できます。

以上でおそらくすべての問題が解決されます。

具体的なソースコード

float TrueMidpoint( float a, float b ) {
    int rm = std::fegetround();

    // Half
    std::fesetround(FE_TOWARDZERO);
    float e = b * 0.5f;
    float f = std::fma( e, -2.0f, b );

    float half_or_zero = std::fma( f, f, -0x1.p+23f ) + 0x1.p+23f;

    // FastTwoSum(?)
    std::fesetround(rm);
    float m_ = std::fma( a, 0.5f, e );
    std::fesetround(FE_TOWARDZERO);
    float h = std::fma( m_, -2.0f, a ) + b;

    std::fesetround(rm);
    float ret = std::fma( h, half_or_zero, m_ );

    return ret;
}

8命令で実現することができました。

FastTwoSumをさかさまに使うと0.0fが返ってくることを利用してfを見ずにできないかといろいろ試しましたが、うまくいく方法を見つけることができませんでした。

検証

次の分岐ありコード(中点はどうやって計算するべきか - kashiの日記で紹介されている方法)と結果が一致するかを確認します。

float TrueMidpointWithIf( float a, float b ) {
    if( std::abs(a) > 1.0f && std::abs(b) > 1.0f ) {
        float ha = a * 0.5f; // a*0.5fは丸め誤差なし
        float hb = b * 0.5f; // b*0.5fは丸め誤差なし
        float ret = ha + hb; // std::fma(a, 0.5f, b*0.5f)で良さそうだけれど、文献通りにしておく
        return ret; // よって、誤差の生じる丸めは一回だけで、所望の方向に丸められる
    } else {
        float sum = a + b; // 非正規化数になるときは丸め誤差なし
        float ret = sum * 0.5f; // sumが正規化数ならば丸め誤差なし
        return ret; // よって、誤差の生じる丸めは一回だけで、所望の方向に丸められる
    }
}

実験に使ったソースコードは以下の通りです。

#include <cstdint>
#include <cstring>
#include <iostream>
#include <random>

template<class To, class From>
To bit_cast( From from ) {
    To to;
    std::memcpy( &to, &from, sizeof to );
    return to;
}

float getRandom( std::mt19937_64& mt ) {
        std::uint64_t s = mt();
        switch( s % 4 ) {
        case 0: return bit_cast<float, std::uint32_t>( ( s >> 32 & 0x80ffffff ) ); // tiny float
        case 1: return bit_cast<float, std::uint32_t>( ( s >> 32 & 0x80ffffff ) + 0x7e800000 ); // large float
        case 2: return bit_cast<float, std::uint32_t>( s >> 32 | 0x007ffff0 ); // large mantissa float
        case 3: return bit_cast<float, std::uint32_t>( s >> 32 ); // random float
        default: return 0.0f;
        }
}

int main( int argc, char* argv[] ) {
        static std::mt19937_64 mt(std::atoi(argv[1]));
        for( std::uint64_t i = 0; i < 100'000'000; ++i ) {
                float a = getRandom(mt);
                float b = getRandom(mt);
                if( a != a || b != b ) { continue; }
                if( bit_cast<std::uint32_t>(TrueMidpoint( a, b )) != bit_cast<std::uint32_t>(TrueMidpointWithIf( a, b )) ) {
                    std::cout << std::hexfloat << a << " " << b << " ";
                    std::cout << TrueMidpoint( a, b ) << " ";
                    std::cout << TrueMidpointWithIf( a, b ) << std::endl;
                }
        }
}

乱数のシード値としては1から16を使いました。 いずれの場合も出力がなかったため、TrueMidpointTrueMidpointWithIfと同じ動作をするようです。

ただし、このテストで確かめられていないコーナーケース入力について、以下の問題があることがわかっています。

  • TrueMidpoint( -0.0f, -0.0f )-0.0fを返してほしいが、0.0fを返してしまう
  • TrueMidpointの引数に無限大が含まれると、NaNが返ってきてしまう

これらは、two-component演算を使っている限り避けようがない問題です。

改良1

上で使った分岐を消す方法を応用して、TrueMidpointWithIf*1に含まれる分岐を消すことを考えます。

float TrueMidpoint( float a, float b ) {
    int rm = std::fegetround();
    float f = 0x1.p-149f / std::fma( b, 4.0f, 0x1.p-148f );
    std::fesetround(FE_TOWARDZERO);
    float one_or_half = std::fma( f, f, -0x1.fffffep+22f ) + 0x1.p+23f;
    std::fesetround(rm);
    float half_or_one = 1.5f - one_or_half;
    float sum_or_mid = std::fma( a, one_or_half, b * one_or_half );
    float ret = sum_or_mid * half_or_one;
    return ret;
}

このようにすると、除算という重たい演算が含まれてしまいますが、入力として無限大が来る場合や、入力の双方が-0.0fである場合も正しく取り扱うことができます。 無限大を普通の数に戻す方法は除算しかないので、この除算は不可避です。

改良2

入力として無限大が来ないならば、除算を取り除くことができます。

float TrueMidpoint( float a, float b ) {
    int rm = std::fegetround();
    float f = b * 0x1.p-149;
    std::fesetround(FE_TOWARDZERO);
    float half_or_one = std::fma( f, f, -0x1.fffffep+22f ) + 0x1.p+23f;
    std::fesetround(rm);
    float one_or_half = 1.5f - half_or_one;
    float sum_or_mid = std::fma( a, one_or_half, b * one_or_half );
    float ret = sum_or_mid * half_or_one;
    return ret;
}

7命令まで減りました。

まとめ

正しく中点を求める、分岐のないアルゴリズムを紹介しました。 命令数が多いので、実用的ではありません。

*1:正確にはfmaを使うように変更してbの絶対値の大小のみで分岐するようにしたバージョン