Karatsuba乗算、あるいはToom–Cook2乗算とは、入力a0
、a1
、b0
、b1
が与えられたとき、a0*b0
、a0*b1+a1*b0
、a1*b1
の三つを出力する際、必要な乗算を四回から三回に減らせるアルゴリズムです。ただし、減算ができることを要求します。
長乗法(筆算と同様の手順による多倍長整数の乗算のこと)や数列の畳み込み演算は、愚直に行うとΘ(N2)の計算コストがかかります。
Karatsuba乗算のアルゴリズムを使って多倍長整数の乗算を行う場合、桁数が半分の整数の乗算を3回やることになります。 ここで出現した乗算にも再帰的にKaratsuba乗算のアルゴリズムを適用していけば、Ο(N1.585)の計算コストとなります。 ここで、指数の1.585は、正確にはです。
具体的なアルゴリズムは、非常に単純です。
a0*b0
、a1*b1
を計算する。乗算がここで二回必要です。p = a0-a1
、q = b0-b1
を計算する。減算がここで必要です。p*q
を計算する。乗算がここでもう一回必要です。a0*b0 + a1*b1 - p*q
を計算する。これがa0*b1 + a0*b1
になっています。
このように計算結果を再利用することで、乗算の回数を減らすことに成功しています。
以下は、この再帰的にKaratsuba乗算を適用したプログラムを最適化する話です。
問題
Convolution (mod 1,000,000,007)
剰余類環*1における畳み込みを行う問題です。
入力長Nは219であり、最も高速な解法は畳み込み定理と高速フーリエ変換(FFT)*2を用いた方法でしょう。
しかし、高速フーリエ変換を用いた手法がKaratsuba乗算より高速になるのはN~5000程度とされており、その百倍程度の入力長であればKaratsuba乗算で強引に押し切ることも可能です*3。
実装
#include <cstdio> #include <cstdint> #include <cinttypes> static constexpr int64_t MOD = 1'000'000'007; static constexpr size_t N_MAX = 1<<19; static constexpr size_t K = 8; template<size_t N> void karatsuba( const int64_t* x, const int64_t* y, int64_t* z ) { static int64_t t[N], p[N/2], q[N/2]; for( size_t i = 0; i < N/2; ++i ) { p[i] = (x[i] - x[i+N/2] + MOD) % MOD; q[i] = (y[i] - y[i+N/2] + MOD) % MOD; } karatsuba<N/2>( x, y, z ); karatsuba<N/2>( x + N/2, y + N/2, z + N ); karatsuba<N/2>( p, q, t ); for( size_t i = 0; i < N/2; ++i ) { int64_t a = z[i]; int64_t b = z[i+N/2]; int64_t c = z[i+N]; int64_t d = z[i+N+N/2]; int64_t e = t[i]; int64_t f = t[i+N/2]; z[i+N/2] = a + b + c - e; z[i+N] = d + b + c - f; } } template<> void karatsuba<K>( const int64_t* x, const int64_t* y, int64_t* z ) { uint64_t tmp[K*2] = {}; for( size_t i = 0; i < K; ++i ) { for( size_t j = 0; j < K; ++j ) { tmp[i+j] += x[i] * y[j]; } } for( size_t i = 0; i < K*2; ++i ) { z[i] = tmp[i] % MOD; } } int main() { static int64_t a[N_MAX] = {}, b[N_MAX] = {}, answer[N_MAX*2]; size_t n, m; scanf( "%zd%zd", &n, &m ); for( size_t i = 0; i < n; ++i ) { scanf( "%" SCNd64, &a[i] ); } for( size_t i = 0; i < m; ++i ) { scanf( "%" SCNd64, &b[i] ); } karatsuba<N_MAX>( a, b, answer ); for( size_t i = 0; i < n+m-1; ++i ) { if( i ) { putchar(' '); } printf( "%" PRId64, (answer[i]%MOD+MOD)%MOD ); } puts(""); return 0; }
工夫その1:静的に領域を確保する
Karatsuba乗算に必要な作業領域は、N
が決まれば静的に決定可能で、動的確保は不要です。
一見再帰呼び出し的なコードになっていますが、常にN
が半分ずつになっていく再帰呼び出しです。
よって、テンプレートで記述すれば毎回異なる関数を呼び出す通常の関数呼び出しとなり、関数内static変数として作業領域を確保することができます。
工夫その2:定数回ループ以外の条件分岐を取り除く
条件分岐は遅くなる原因であるので、取り除きます。
0埋めされた計算とはいえ、常に最大ケースと同じだけの計算量が必要になっているので、小さいケースでも時間がかかります。
ジャッジサーバーに負荷をかけてすみません。
工夫その3:小さい部分では再帰しない
Karatsuba乗算は、乗算の回数がオーダーレベルで減る代わりに加減算の回数が増えるという性質があります。
そのため、入力長N
が非常に小さい部分ではΘ(N2)の愚直な方法が高速になります。
定数での剰余演算は、乗算命令二回に最適化されます。これは、a%b == a-a/b*b
であること、定数での除算は逆数の乗算に最適化できる*4ことによります。
このことを考慮に入れた、切り替え入力長K
以下で愚直計算に切り替えた場合の乗算命令数は以下のようになります。
K |
剰余演算に必要な乗算命令数 | 愚直計算に必要な乗算回数 | 合計 |
---|---|---|---|
1 |
~319×8 | 319 | 104.6G |
2 |
~318×16 | 318×4 | 77.5G |
4 |
~317×32 | 317×16 | 62.0G |
8 |
~316×64 | 316×64 | 55.1G |
16 |
~315×128 | 315×256 | 55.1G |
32 |
~314×256 | 314×1024 | 61.2G |
この結果を見ると、乗算回数としてはK=8
またはK=16
がよさそうということになります。
加減算の数のことまで考えると、N=16
の時にKaratsuba法を適用してN=8
とするのは乗算回数が減らないのに加減算だけが増える悪手であり、K=16
とするのが適切であると結論付けることができます。
しかし、実験の結果、入力長が8のところで方法を切り替えるのが最も高速という結論になりました。
入力長が16のところで切り替えたほうがやや計算量が少なくなるはずですが、コンパイラのインライン展開の限界の影響か、実測では遅くなりました。
切り替え入力長が8であっても、愚直計算のロード命令の数を最小化しようとするとレジスタが20個ほど必要で、x86の持つ汎用レジスタの数16個を超えてしまいます。 切り替え入力長を16にしてしまうとさらにロード命令が増えてしまいます。 これが切り替え入力長16で遅くなる原因と思われます。
工夫その4:なるべく剰余演算を行わない
定数での剰余演算は、乗算命令二回に最適化されますが、それでも重い演算であることには変わりがなく、なるべく行いたくありません。
プログラムの入力は0
以上MOD
未満ですが、途中で加減算をするとこの制約が満たされなくなります。
そのため、桁あふれを防ぐため、少なくとも乗算の前までには0
以上MOD
未満の制約を満たすようにしておく必要があります。
karatsuba<K>
の最初で剰余をとっても良いですが、もともとの入力は剰余をとらなくてもいいことを考えると、p
やq
を計算した後に行うのが最も回数が少なくなります(回数の差はほとんどはありません)。
乗算結果はMOD*MOD
未満であり、これを8個足すくらいでは*5uint64_t
は桁あふれしないため、乗算ごとに剰余をとる必要はありません。
karatsuba<K>
が終わるときに剰余をとればよさそうです。
z[i+N/2] = a + b + c - e;
の部分で四つ足すため、再帰から一段戻るたびに絶対値が4倍程度になりますが、再帰深度はlog(219/8)=16なので、絶対値は232倍にしかなりません。
MOD
の232倍はint64_t
に収まるため、ここでも剰余をとる必要はありません。
最後の出力する部分で剰余をとればよいでしょう。
工夫その5:データのコピー回数を減らす
出力すべき答えのうち前1/4と後1/4は、再帰呼び出しで得られる答えの前1/2や後1/2と正確に一致します。 ここのコピーにかかる時間も実行時間にかなり影響してくるので、再帰関数にポインタを渡して直に書き込んでもらうことでコピーコストを減らします。
ポインタで渡すのはエイリアス解析が困難になりそうですが、最近のコンパイラだと意外とやってくれるようです(?)
更なる最適化
p[i] = (x[i] - x[i+N/2] + MOD) % MOD;
としていた部分を
p[i] = x[i] - x[i+N/2]; p[i] += MOD * (p[i] < 0);
と変形することで演算強度を下げる最適化です。x[i]
とx[i+N/2]
は0
以上MOD
未満であることが保証されているので、この変形は無問題です。
p[i] += MOD * (p[i] < 0);
の部分は、手元のg++-8
やg++-9
を使った場合、以下のどれを用いてもほとんど同じ速度となりました。
p[i] = p[i] < 0 ? p[i] + MOD : p[i];
p[i] += p[i] < 0 ? MOD : 0;
p[i] <0 && (p[i] += MOD);
if( p[i] < 0 ) p[i] += MOD;
いずれもcmovs
命令を用いたコードにコンパイルされます。
ジャッジサーバー*6では、if文を用いたものは低速だったようです。テストケースごとの傾向から、これは分岐予測ミスによるものだと考えられます。 ここの分岐は予測不可能な分岐であるため、条件分岐命令にコンパイルされると速度が低下するのでしょう。
ただし、wandboxを使った調査では、g++-6以降やclang++の場合はif
文を用いても速度低下はありませんでした。
これを用いた場合、手元のCPU(Intel(R) Core(TM) i7-6700HQ CPU @ 2.60GHz)では(user時間で)2.0秒程度で実行が完了します。 このCPUのターボブースト時クロック周波数は3.50GHzであり、このプログラムに含まれる乗算41.3億回を行うためにはそれだけで最低でも1秒以上かかります*7。 Karatsuba乗算にはこれ以外にも加減算が多く含まれることを考えると、かなり限界に近い性能が出せていることになるでしょう。