64bit乗算の上位部分の計算

x86RISC-Vなどの多くの機械語命令セットには、64ビット整数同士の積(128ビット整数になる)の上位64ビットを求める命令が存在します。 しかし、C言語C++の言語標準には128ビット整数が存在しないため、これを簡単に得ることができません。 gccなどのコンパイラには非標準ながら__uint128_tのようなものがあり、これを使うことでそういった命令を呼び出すことができますが、標準の範囲で計算するにはどのようにすればよいかを考えます。

以下の議論はuint64_tint64_tがない処理系でも大丈夫(64ビット以上の整数、つまりint_least64_tなどで十分)なはずですが、確かめられる処理系を持っていないので確かめていません。

MULHUU

uint64_tuint64_tの積の上位64ビットを求めます。 つまり、以下のコードと同じことを実現します。

uint64_t MULHUU( uint64_t x, uint64_t y ) {
    return (__uint128_t)x * y >> 64;
}
# x86-64
MULHUU(unsigned long, unsigned long):
        mov     rax, rdi
        mul     rsi
        mov     rax, rdx
        ret
# RISC-V
MULHUU(unsigned long, unsigned long):
        mulhu   a0,a0,a1
        ret

これを求めるにはまず、積がuint64_tに収まる範囲にするために、xyをそれぞれ上位32ビットと下位32ビットに分割します。

uint64_t xl = x & 0xffffffff;
uint64_t xh = x >> 32;
uint64_t yl = y & 0xffffffff;
uint64_t yh = y >> 32;

その後、長乗算(筆算)を行えば答えを求めることができます。 ただし、32ビット整数同士の積(64ビット整数になっている)を二つ足すとオーバーフローして計算が合わなくなることに注意します。 これを回避するために、32ビット整数同士の積(64ビット整数)をもう一度上位32ビットと下位32ビットに分割します。

uint64_t z0 = xl * yl;
uint64_t z1 = xl * yh;
uint64_t z2 = xh * yl;
uint64_t z3 = xh * yh;

uint64_t z0h = z0 >> 32;
uint64_t z1l = z1 & 0xffffffff;
uint64_t z1h = z1 >> 32;
uint64_t z2l = z2 & 0xffffffff;
uint64_t z2h = z2 >> 32;

uint64_t  s = z0h + z1l + z2l;
return z3 + z1h + z2h + ( s >> 32 );

z0lは必要ありません。以下のような筆算の図を見てみるとわかりやすいと思います。

                                   [.....z0h......][.....z0l......]
                   [.....z1h......][.....z1l......]
                   [.....z2h......][.....z2l......]
   [..............z3..............]
----------------------------------------------------------------------
                                 [.......s........][.....z0l......]
                   [.....z1h......]
                   [.....z2h......]
   [..............z3..............]
   <----------必要な範囲----------->

MULHSS

int64_tint64_tの積の上位64ビットを求めます。 つまり、以下のコードと同じことを実現します。

// 負数の表現が二の補数で符号付き整数に対する右シフトが算術シフトになる処理系
int64_t MULHUU( int64_t x, int64_t y ) {
    return (__int128_t)x * y >> 64;
}
# x86-64
MULHSS(long, long):
        mov     rax, rdi
        imul    rsi
        mov     rax, rdx
        ret
# RISC-V
MULHSS(long, long):
        mulh    a0,a0,a1
        ret

128ビット整数に符号拡張してから計算するという手順もありますが、乗算の回数が増えてしまいます。 ここではうまく符号付きの乗算を使って4回の乗算で行う方法を示します。 ちなみに、絶対値をMULHUU関数に入力してから適切な符号をつける、という手順は誤りです(-1×1の上位は-1になるべきですが、0になってしまいます)。

まず、先ほどと同様に上位32ビットと下位32ビットに分割しますが、その際に上位は符号付きにしておきます。

uint64_t xl = x & 0xffffffff;
int64_t xh = x >> 32;
uint64_t yl = y & 0xffffffff;
int64_t yh = y >> 32;

ただし、符号付き整数のシフトが算術シフトになる処理系を仮定しました。 シフト部分を以下のように書けば、そうでない処理系にも対応しつつ、clangやgccで算術シフトに最適化されるそうです(C言語における移植性のある算術シフトの記述方法 - よーるのコメントに書かれていました🙏)。

int64_t xh = x < 0 ? -(-(x+1) >> 32) - 1 : x >> 32;

また、負の数の表現が二の補数である処理系を仮定しました。 以下のように書いたほうが移植性が高いですね。

uint64_t xl = (uint64_t)y & 0xffffffff;

さて、この後はやはり長乗算(筆算)を行えば結果を求めることができます。 xhの絶対値は231以下、ylの絶対値は232未満なので、その積の絶対値は263未満になり、64ビット符号付き整数に収まることに注意します。 乗算結果のビット表現は符号付きで計算しても符号なしで計算しても同じですが、int64_tに格納するときに符号付き整数オーバーフローの未定義動作が起こらないようにする必要があります。 最も簡単なのは、掛け算を符号付きで行うことでしょう。

uint64_t z0 = xl * yl;
int64_t z1 = (int64_t)xl * yh;
int64_t z2 = xh * (int64_t)yl;
int64_t z3 = xh * yh;

後は先ほどと同じで大丈夫です。 ただし、結果を符号付き整数で返す場合、途中計算を符号なし整数で行ってしまうと、符号付き整数に戻す段階で符号付き整数オーバーフローの未定義動作が生じる可能性があります。 途中計算を符号付き整数で行ってもオーバーフローは起きないため、途中計算も符号付き整数で行うのが簡単です。

int64_t z0h = z0 >> 32;
int64_t z1l = z1 & 0xffffffff;
int64_t z1h = z1 >> 32;
int64_t z2l = z2 & 0xffffffff;
int64_t z2h = z2 >> 32;

int64_t  s = z0h + z1l + z2l;
return z3 + z1h + z2h + ( s >> 32 );

MULHSU

int64_tuint64_tの積の上位64ビットを求めます。 つまり、以下のコードと同じことを実現します。

int64_t MULHSU( int64_t x, uint64_t y ) {
    return (__int128_t)x * y >> 64;
}
# x86-64
MULHSU(long, unsigned long):
        mov     rax, rsi
        mul     rdi
        sar     rdi, 63
        imul    rdi, rsi
        lea     rax, [rdi + rdx]
        ret
# RISC-V
MULHSU(long, unsigned long):
        mulhsu    a0,a0,a1
        ret

あまり出てこなさそうな形式の乗算で、実際x86にはこれに相当する命令がありません(x86の場合は、符号拡張して128ビット符号なし整数×64ビット符号なし整数の乗算に持ち込んでいます)。 専用命令があるRISC-Vのマニュアルを読んでみると、多倍長の符号付き整数の乗算を実装するために使えるそうです。 たしかに先ほどのxh * yl部分は符号付き整数と符号なし整数の乗算になっていました。

この形式の場合は、以下のように上位32ビットと下位32ビットに分割します。

uint64_t xl = x & 0xffffffff;
int64_t xh = x >> 32;
uint64_t yl = y & 0xffffffff;
uint64_t yh = y >> 32;

長乗算(筆算)部分は以下のようになります。

uint64_t z0 = xl * yl;
uint64_t z1 = xl * yh;
int64_t z2 = xh * (int64_t)yl;
int64_t z3 = xh * (int64_t)yh;

後は先ほどと同じようにやれば大丈夫です。

int64_t z0h = z0 >> 32;
int64_t z1l = z1 & 0xffffffff;
int64_t z1h = z1 >> 32;
int64_t z2l = z2 & 0xffffffff;
int64_t z2h = z2 >> 32;

int64_t  s = z0h + z1l + z2l;
return z3 + z1h + z2h + ( s >> 32 );

統一する

以下のように書けば、RISC-Vの三種類の命令を実装するときのコードを短くすることができます。 ※C++20なので負の数の表現は二の補数です!

template<class LHS, class RHS>
constexpr std::uint64_t mulhxx64( std::uint64_t lhs, std::uint64_t rhs ) noexcept {

    const std::uint64_t lhs0 = lhs & 0xffffffff;
    const LHS           lhs1 = std::bit_cast<LHS>(lhs) >> 32;
    const std::uint64_t rhs0 = rhs & 0xffffffff;
    const RHS           rhs1 = std::bit_cast<RHS>(rhs) >> 32;

    const std::uint64_t d = lhs1 * rhs1;
    const std::uint64_t c = lhs1 * static_cast<LHS>(rhs0) >> 32;
    const std::uint64_t b = static_cast<RHS>(lhs0) * rhs1 >> 32;
    const std::uint64_t a = ( lhs0 * rhs0 >> 32 ) + ( lhs0 * rhs1 & 0xffffffff ) + ( lhs1 * rhs0 & 0xffffffff );

    return d + c + b + ( a >> 32 );
}
constexpr std::uint64_t MULHUU( std::uint64_t lhs, std::uint64_t rhs ) noexcept {
    return mulhxx64<std::uint64_t, std::uint64_t>( lhs, rhs );
}
constexpr std::uint64_t MULHSU( std::uint64_t lhs, std::uint64_t rhs ) noexcept {
    return mulhxx64<std::int64_t, std::uint64_t>( lhs , rhs );
}
constexpr std::uint64_t MULHSS( std::uint64_t lhs, std::uint64_t rhs ) noexcept {
    return mulhxx64<std::int64_t, std::int64_t>( lhs, rhs );
}

このコードは見た目上6つの乗算を含みますが、clangやgccは4つの乗算に最適化します。 Nビット整数同士の乗算結果の下位Nビットは、その乗算が符号付きか符号なしかで変化しないためです(これはmod 2Nで考えてみれば明らかです)。

まとめ

64ビット整数同士の乗算結果の上位64ビットを求める、コンパイラの独自機能に頼らない方法を示しました。 また、符号付きと符号なしの場合のどちらにも対応できる、テンプレートを使った実装も示しました。