128bit除算のやり方

五月祭お疲れさまでした。私はクイックソートのメモリアクセス系列の可視化の展示をしていましたが、楽しんでいただけたでしょうか……。

C言語には多倍長整数のような便利なものがありません。unsigned long longを使えば64bit整数までは取り扱えますが、それ以上になると面倒です。 ちょっと64bitをはみ出すくらいの整数を扱いたい、128bit整数型があれば……。という場合、gccやclangなら__uint128_tという型が使えます。これはその名前の通り、128bitの整数型です。 ここではx86_64を前提として、64bitマシンでどのように128bitの演算を行っているか見てみます。

足し算・引き算

x86にはadc命令・sbb命令が存在します。これはキャリー付きの足し算・ボロー付きの引き算を行う命令です。 下位64bitの計算を行った後上位64bitの計算を行うとき、下位64bitの計算での繰り上がり・繰り下がり(キャリー・ボロー)を考慮に入れた計算を行うことができます。 このように多倍長演算用の命令が用意されているため、足し算・引き算は簡単に行えます。

掛け算

x86にはmul命令が存在します。これはimul命令と異なり、64bit整数同士の積を128bit整数として求めることができる命令です。 128bitのレジスタはありませんが、結果は上位64bitと下位64bitを異なる二つのレジスタに分けて書き込まれます。

あとは通常の掛け算を3回行い、筆算のように結果を足し合わせれば128bit整数同士の積(の下位128bit分)を求めることができます。

割り算

割り算はそう簡単にはいきません。gccの場合、__udivti3というライブラリ関数呼び出しにコンパイルします。 この関数を逆アセンブルした結果、__udivti3関数は以下のような254Byteのコードになっているようです。 機械語の雰囲気はコンパイラの出力っぽいコードで、プログラマが書いた雰囲気を感じ取ることができませんでした。

   0x0000000000400560 <+0>:     mov    %rcx,%r8
   0x0000000000400563 <+3>:     mov    %rdx,%r9
   0x0000000000400566 <+6>:     mov    %rdx,%r10
   0x0000000000400569 <+9>:     test   %r8,%r8
   0x000000000040056c <+12>:    mov    %rdx,%rcx
   0x000000000040056f <+15>:    jne    0x4005a8 <__udivti3+72>
   0x0000000000400571 <+17>:    cmp    %rsi,%rdx
   0x0000000000400574 <+20>:    ja     0x400628 <__udivti3+200>
   0x000000000040057a <+26>:    test   %rdx,%rdx
   0x000000000040057d <+29>:    jne    0x40058c <__udivti3+44>
   0x000000000040057f <+31>:    mov    $0x1,%eax
   0x0000000000400584 <+36>:    xor    %edx,%edx
   0x0000000000400586 <+38>:    div    %r9
   0x0000000000400589 <+41>:    mov    %rax,%rcx
   0x000000000040058c <+44>:    mov    %rsi,%rax
   0x000000000040058f <+47>:    xor    %edx,%edx
   0x0000000000400591 <+49>:    div    %rcx
   0x0000000000400594 <+52>:    mov    %rax,%rsi
   0x0000000000400597 <+55>:    mov    %rdi,%rax
   0x000000000040059a <+58>:    div    %rcx
   0x000000000040059d <+61>:    mov    %rsi,%rdx
   0x00000000004005a0 <+64>:    retq
   0x00000000004005a1 <+65>:    nopl   0x0(%rax)
   0x00000000004005a8 <+72>:    cmp    %rsi,%r8
   0x00000000004005ab <+75>:    ja     0x400620 <__udivti3+192>
   0x00000000004005ad <+77>:    bsr    %r8,%rax
   0x00000000004005b1 <+81>:    xor    $0x3f,%rax
   0x00000000004005b5 <+85>:    test   %eax,%eax
   0x00000000004005b7 <+87>:    mov    %eax,%r11d
   0x00000000004005ba <+90>:    je     0x400638 <__udivti3+216>
   0x00000000004005bc <+92>:    mov    %eax,%ecx
   0x00000000004005be <+94>:    mov    $0x40,%edx
   0x00000000004005c3 <+99>:    shl    %cl,%r8
   0x00000000004005c6 <+102>:   movslq %eax,%rcx
   0x00000000004005c9 <+105>:   sub    %rcx,%rdx
   0x00000000004005cc <+108>:   mov    %edx,%ecx
   0x00000000004005ce <+110>:   shr    %cl,%r9
   0x00000000004005d1 <+113>:   mov    %eax,%ecx
   0x00000000004005d3 <+115>:   or     %r9,%r8
   0x00000000004005d6 <+118>:   shl    %cl,%r10
   0x00000000004005d9 <+121>:   mov    %rsi,%r9
   0x00000000004005dc <+124>:   mov    %edx,%ecx
   0x00000000004005de <+126>:   shr    %cl,%r9
   0x00000000004005e1 <+129>:   mov    %eax,%ecx
   0x00000000004005e3 <+131>:   mov    %rdi,%rax
   0x00000000004005e6 <+134>:   shl    %cl,%rsi
   0x00000000004005e9 <+137>:   mov    %edx,%ecx
   0x00000000004005eb <+139>:   mov    %r9,%rdx
   0x00000000004005ee <+142>:   shr    %cl,%rax
   0x00000000004005f1 <+145>:   or     %rax,%rsi
   0x00000000004005f4 <+148>:   mov    %rsi,%rax
   0x00000000004005f7 <+151>:   div    %r8
   0x00000000004005fa <+154>:   mov    %rdx,%r9
   0x00000000004005fd <+157>:   mov    %rax,%rsi
   0x0000000000400600 <+160>:   mul    %r10
   0x0000000000400603 <+163>:   cmp    %rdx,%r9
   0x0000000000400606 <+166>:   jb     0x400618 <__udivti3+184>
   0x0000000000400608 <+168>:   mov    %r11d,%ecx
   0x000000000040060b <+171>:   shl    %cl,%rdi
   0x000000000040060e <+174>:   cmp    %rax,%rdi
   0x0000000000400611 <+177>:   jae    0x400658 <__udivti3+248>
   0x0000000000400613 <+179>:   cmp    %rdx,%r9
   0x0000000000400616 <+182>:   jne    0x400658 <__udivti3+248>
   0x0000000000400618 <+184>:   lea    -0x1(%rsi),%rax
   0x000000000040061c <+188>:   xor    %edx,%edx
   0x000000000040061e <+190>:   retq
   0x000000000040061f <+191>:   nop
   0x0000000000400620 <+192>:   xor    %edx,%edx
   0x0000000000400622 <+194>:   xor    %eax,%eax
   0x0000000000400624 <+196>:   retq
   0x0000000000400625 <+197>:   nopl   (%rax)
   0x0000000000400628 <+200>:   mov    %rdi,%rax
   0x000000000040062b <+203>:   mov    %rsi,%rdx
   0x000000000040062e <+206>:   div    %r9
   0x0000000000400631 <+209>:   xor    %edx,%edx
   0x0000000000400633 <+211>:   retq
   0x0000000000400634 <+212>:   nopl   0x0(%rax)
   0x0000000000400638 <+216>:   cmp    %rsi,%r8
   0x000000000040063b <+219>:   jb     0x40064a <__udivti3+234>
   0x000000000040063d <+221>:   xor    %edx,%edx
   0x000000000040063f <+223>:   xor    %eax,%eax
   0x0000000000400641 <+225>:   cmp    %rdi,%r9
   0x0000000000400644 <+228>:   ja     0x4005a0 <__udivti3+64>
   0x000000000040064a <+234>:   xor    %edx,%edx
   0x000000000040064c <+236>:   mov    $0x1,%eax
   0x0000000000400651 <+241>:   retq
   0x0000000000400652 <+242>:   nopw   0x0(%rax,%rax,1)
   0x0000000000400658 <+248>:   mov    %rsi,%rax
   0x000000000040065b <+251>:   xor    %edx,%edx
   0x000000000040065d <+253>:   retq

これを少しは読みやすい疑似コードにしてみたのが以下です。 以下では、上位64bitがh、下位64bitがlであるような128bit変数の意味で(y#l)という表記を使います。変数はすべてレジスタサイズと同じ64bit整数です。

__udivti3( __uint128_t x, __uint128_t y ) {
  (xh#xl) = x;
  (yh#yl) = y;
  if( yh == 0 ) {
    if( xh >= yl ) {
      // 商が64bitに収まらない時
      if (yl == 0) {
        // 零除算例外を発生させる
        rcx = 1 / yl;
      }
      // 長除法
      uint64_t qh = xh / yl;
      uint64_t r = xh % yl;
      uint64_t ql = (r#xl) / yl;
      return (qh#ql);
    } else {
      // 商が64bitに収まる
      return (xh#xl) / yl;
    }
  } else {
    if( xh >= yh ) {
      // yh != 0なので、count_trailing_zerosを計算するbsr命令が使える
      S = 63^count_trailing_zeros( yh );
      if( S != 0 ) {
        // yの上から64bitをyhに集める
        (yh#yl) = (yh#yl) << S;
        // 商の上界値を求める
        q = ((xh#xl) >> (64-S)) / yh;
        r = ((xh#xl) >> (64-S)) / yh;
        (mh#ml) = q * yl;
        if( mh >= r ) {
          // xl << Sは、(xh#xl)>>(64-S)ではみ出した部分
          if( q >= xl << S || mh != r ) { return q; }
        }
        return q - 1;
      } else {
        // (yh#yl)が2の127乗以上の時
        if (xh <= yh && xl < yl) {
          // (xh#xl) < (yh#yl) の時(xh == yh && xl < yl)
          return 0;
        }
        return 1;
      }
    } else {
      // (xh#xl) < (yh#yl) の時(xh < yh) 
      return 0;
    }
  }
}

x86には、被除数が128bit、除数が64bitの割り算命令divが存在します。 そのため、除数が64bitに収まっている場合、単純に長除法(筆算)を行うだけで終わりです。 しかし、除数が64bitに収まらない場合、そのような計算を行う命令がなく、工夫が必要です。

概念的には、次のような計算を行っています。 まずxとyを両方ともSビットシフトして、yの最上位ビットが1になるようにします。こうすることで、除数の値を切り捨てではありますが二進有効数字64桁で表すことができます。 この状態で割り算をすると、除数を切り捨てたことから、商の上界値qが求まります(被除数も切り捨てているので議論がややこしいですが、切り捨てられた部分の大きさは1未満なので割り算の結果に影響を及ぼしません)。 この時、真の商としてあり得るのはqq-1になります(真の除数は切り捨てた部分の影響は2^-63未満であることによります)。 後はylxlの情報を用いてそのどちらかになるかを決定しています(この部分はよく確認していません)。