クイックソートのピボットは中央値でなく四分位数を選択したほうが高速

概要

クイックソートは一般に高速なソートアルゴリズムとして知られています。 アルゴリズム中にはピボットを選択する部分がありますが、ここでなるべく中央値に近い値を選択すると、総比較回数を少なくすることができます。 しかし、最近の高性能プロセッサは投機実行するため、比較回数ではなく分岐予測ミスの回数が性能に大きく影響します。 ピボットとして中央値ではなく四分位数を選択する場合、比較回数が23%増加するものの、分岐予測ミス回数を38%削減し、ソートにかかる時間を20%削減できることが分かりました。

はじめに

去年の末、kimiyukiさん(kimiyuki@うさぎ🐇 (@kimiyuki_u) | Twitter)に「クイックソートのピボットとして中央値ではない値を選んだほうが、分岐予測ミス率が減って高速なのでは」という話をいただきました。 その時私は、「分岐予測ミス率を減らせても、トータルの仕事量が増えてしまって意味がないのでは」と怪しんでいましたが、実際に実験してみると高速化することがわかりました。

以下では、高速化する原因を、モデルを用いて明らかにします。

実際に計測に用いたコード(ピボットをオラクル的に選択できるように改造したクイックソート)は以下です。

#include <array>
#include <utility>

std::array<double, 10000000> arr;

void q_sort( double* begin, double* end, double begin_v, double end_v ) {
        if( end - begin <= 1 ) { return; }

        const double pivot = begin_v*0.75 + end_v*0.25;

        double* p = begin;
        double* q = end;

        for( ; ; ) {
                while( p < q && *p < pivot ) ++p;
                do --q; while( p < q && *q > pivot );
                if( p >= q ) break;
                std::swap( *p, *q );
                ++p;
        }

        q_sort( begin, p, begin_v, pivot );
        q_sort( p, end, pivot, end_v );
}

#include <random>

int main() {
        std::mt19937 mt;
        std::uniform_real_distribution<double> dist(1.0, 2.0);
        for( auto& e : arr ) {
                e = dist(mt);
        }
        q_sort( arr.begin(), arr.end(), 1.0, 2.0 );
}

モデル化

理論

N要素に対するソートアルゴリズムは、N!通りある入力に対して1通りの出力を返すので、エントロピー \log_2(N!) ビット下げる必要があります。

クイックソートでは、ピボットとして中央値を選んだ場合、一回の分岐命令で1ビットのエントロピーを下げることができます。 一般に、ピボットとしてp:(1-p)に内分する点を選んだ場合、一回の分岐命令で H(p) = - p\log_2(p) - (1-p)\log_2(1-p)ビットのエントロピーを下げることができます。

よって、ピボットとしてp:(1-p)に内分する点を選んだ場合、 \log_2(N!) / H(p) 回の比較が少なくとも必要です*1

また、ソートを行うためには比較した後に分岐する必要がありますが、この分岐の飛び先には規則性がありません。 分岐予測器は、単に確率の高いほうに分岐するだろうと予測するしかなく、実際にそのように動作すると考えてよいです。 この場合の分岐予測成功確率は、 \max(p, 1-p)となります。

よって、ピボットとしてp:(1-p)に内分する点を選んだ場合、分岐予測ミスは、 \max(p, 1-p) \times \log_2(N!) / H(p)回発生することになります。

プロセッサのパラメータ

Intel社のSkylakeマイクロアーキテクチャでは、分岐予測ミスペナルティがおおよそ 20 cycle です。また、このコードの場合、ミスがなかった場合と比べて追加で 9 cycle 損します。これは、パイプラインがフラッシュされ、メモリアクセス(5 cycle)と浮動小数点数比較(4 cycle)がやり直しになるためです。

分岐予測ミスが発生しない場合、比較にかかる時間*2は 1 cycle です。

また、実測値を使うと、end - begin > 1の時の関数呼び出しのオーバーヘッドは、約 50 cycle でした。

クイックソートにかかる時間

ソートが完了するまでにかかる時間は、ピボット選択位置pをパラメータとして、

  • 分岐予測ミスペナルティ 29 cycle ×  \max(p, 1-p) \times \log_2(N!) / H(p)
  • 比較コスト 1 cycle ×  \log_2(N!) / H(p)
  • 関数呼び出しコスト 50 cycle ×  \log_2(N) / H(p)

の和としてモデル化できます。

その結果、p=0.2~0.25付近で最小値を取ることがわかります。

定性的な説明

ピボットとして中央値を選択すると、分岐予測ミス確率は50%です。分岐一回で減らせるエントロピーは、1ビットです。よって、分岐予測ミス一回当たりで減らせるエントロピーは、2ビットです。

一方、ピボットとして四分位数を選択した場合、分岐予測ミス確率は25%です。分岐一回で減らせるエントロピーは、0.811ビットです。よって、分岐予測ミス一回当たりで減らせるエントロピーは、3.2ビットです。

このように、減らせるエントロピーが小さくなる(=トータルの仕事量が多くなってしまう)効果よりも、分岐予測ミス確率が減る効果のほうがはるかに大きいです。

一般に、偏ったピボットをとるほど、一回の分岐予測ミスで減らせるエントロピーは大きくなります。

比較コストや関数呼び出しコストより分岐予測ミスペナルティのほうが十分大きい領域(0.2<p<0.8)においては、なるべく偏ったピボットを選択したほうが効率的です。

ただし、ピボットとしてあまりに偏った値を選択すると、比較コストや関数呼び出しコストが支配的となり、かえって低速になります。

おわりに

クイックソートのピボットは中央値を選ぶのが、比較回数が最小となるため望ましいというのが定説でした。 実際に中央値を選択するのは難しいですが、複数のピボット候補の中央値を選択することでなるべく中央値に近いと思われるピボットを選択する手法が一般に用いられています。

しかし、分岐予測ミスペナルティが非常に大きな近年の高性能プロセッサの場合、中央値に近い値をピボットとして選択するのは、かえって悪手である可能性があることがわかりました。 複数のピボット候補からピボットを選択する手法も、今後は見直す必要がありそうです。

*1:再帰呼び出し終端付近では無駄な比較が多発するため、約1.2倍の比較が実行されるようです

*2:正確にはスループット

sshでログインできなくなった問題と解決策

リモートマシンにsshログインしようとしたところ、パスフレーズの入力後に何も反応しなくなることがありました。

正しくない公開鍵を使った場合や、正しくないパスフレーズを入力した場合はログインできないと応答が返ってくるので、リモートマシンの側の問題ではなさそうです。

-vvvオプションをつけてログを見たところ、first channel unpausesという不思議なログが出ていました。

このメッセージでgoogle検索したところ、

PowerShellでRestart-Service LxssManagerを実行せよ」

という解決法を発見しました。

しかし、そもそもそのようなコマンドは存在せず、解決しませんでした。

Windows自体を再起動したところ、問題が解決しました。

64bit符号なし整数を素因数分解するのにかかる時間

64bit符号なし整数を素因数分解するため、事前に32bit符号なし整数で表せる素数をリストアップしておくとします。

方法1: 素数のリストをビットベクトルで保持し、試し割り

32bit符号なし整数で表せる奇数が素数であるかを、0, 1で表した列(ビットベクトル)を事前に作成し、それを素因数分解プログラムのバイナリに埋め込みます。 ビットベクトルのエントリ数は2Gi個で、1Byte=8bitであるとして256MiByteのバイナリが完成します。

32bit符号なし整数で表せる素数は203280221個(203M個)存在するので、この方法を使うと最悪で203M回の除算が発生します。

最近のCPUであっても除算器はパイプライン化されていないので、203M回×除算器の占有サイクル数くらいの時間がかかります。

この方法を使ったとき、手元のマシンでは、4.8秒くらいかかりました。ちょっと時間がかかりすぎな気がします。

方法2: 素数の逆元をファイルから読み込み、判定

m % p == 0という剰余演算はpが奇数の定数であれば、pの264における逆元を用いて最適化可能です。これを事前に計算しておく方式です。一つの素数当たり16Byteの情報が必要なため、バイナリに埋め込むことはできませんでした。

事前に計算した定数をファイル(バイナリファイルで3.03GiB、おそらくメモリ上にキャッシュ済み)から読み込んだ場合、素因数分解には5.2秒かかりました。一つの素数で割れるかを確認するのは1cycleで可能で非常に高速なのですが、読み込みに5.0秒くらいかかっているので、この方式をこれ以上改善することはできなさそうです。

mmapで読み込んだら、読み込み1秒、計算1秒くらいまで高速化しました。 素因数が小さいとき、読み込みが高速になるのもよい点です。

さすがに事前計算に時間がかかっているだけあって高速な手法です。

方法3: ポラードのρ法

以前に実装した高速なgcd実装を用い、フロイドの循環検出法を用いたオリジナルバージョンを実装しました。

素数であるかわかってない入力が来る場合、ポラードのρ法は素数であるかを確認するのに非常に不向きなので、別途素数判定アルゴリズムを併用することになります。

素数でないとわかっている入力に対し、何らかの因数を見つけるのにかかる時間は、(乱択アルゴリズムなのでぶれが大きいですが)平均5秒、最悪20秒程度でした。

素因数は232ほどの大きさであり、これの平方根(65536)くらいの回数の乱数生成で発見できると見積もるようですが、実際には数千万個の乱数を生成しないと素因数が発見できない例が多々ありました。

事前の計算は不要でプログラムサイズも小さいですが、素数のリストを事前に作成した試し割り法よりは低速だということがわかりました。

ABC152のE問題 Flatten を遅い方法で解いた

以下にはABC152のE問題 Flatten のネタバレが含まれます。


この問題は多倍長整数を使用できる言語であれば愚直解が通ります。Haskellで書かれた愚直解のコードは、例えばSubmission #9632460 - AtCoder Beginner Contest 152が簡潔で読みやすいです。

foldr lcm 1 aの部分は、 N回「 O\left(N\log A\right)ビットの整数と O\left(\log A\right)ビットの整数の最大公約数(高々 O\left(\log A\right)ビット)を求め、 O\left(\log A\right)ビットの整数をそれで割り、その結果と O\left(N\log A\right)ビットの整数を掛け合わせる」という計算を行うはずです。 O\left(N\log A\right)ビットの整数と \log Aビットの整数の最大公約数をユークリッドの互除法で求めると、多倍長除算は長除法(筆算)を用いるとして、時間計算量が O\left(\left(N + \log A\right)\left(\log A\right)^2\right)となります。

残りの部分は、取るに足らない計算量です。

制約は N=10^4, A=10^6なので、愚直解は雑な概算でΟの中が 4\times 10^9になります。時間制限は2秒であり、 5\times 10^9サイクル以上のCPU時間があるので、これは余裕で間に合います。実際、多くのHaskellでの愚直解の提出は、200msから300msくらいで動作しています。


C++標準には残念ながら多倍長整数がありません。 しかし、この問題に限っては位取り記数法を用いた多倍長整数を使う必要はありません。 掛け算せず、因数*1の列として保持すればよいです。

それを利用した解法の提出が以下です。

Submission #9637654 - AtCoder Beginner Contest 152

全然間に合っていません。実際、この解法の時間計算量は O\left(N^2 \left(\log A\right)^3\right)であり、愚直解より20倍ほど遅いことがわかります。つまり、3倍程度間に合いません。

手元で作った雑な最大ケースでは間に合ったため大丈夫だろうと思っていたのですが、よくよく考えると真の最大ケースは、「 N個の A未満の相異なる素数」です(テストケースの実行時間から推察すると、おそらくmax_02がこれです)。

そのケースを手元で動かした場合からの概算だと、ジャッジサーバーでは6秒くらいかかりそうだという感じだったので、いろいろな計算が大体合います。


そこで、自作の高速化したgcd関数を使えば間に合うのではないかと考えました。

以下は、最大公約数をもっと高速に求める(その4) - よーるで示した高速なgcd関数を、入力が \frac{2^{64}}5を超えない場合に特化したものです。

std::uint64_t gcd_impl( std::uint64_t n, std::uint64_t m ) {
    constexpr std::uint64_t K = 5;
    if( m == 0 ) { return n; }
    for( ; n / 64 < m; ) {
        std::uint64_t t = n - m;
        std::uint64_t s = n - m * K;
        bool q = t < m;
        bool p = t < m * K;
        n = q ? m : t;
        m = q ? t : m;
        if( m == 0 ) { return n; }
        n = p ? n : s;
    }
    return ::gcd_impl( m, n % m );
}
 
std::uint64_t gcd( std::uint64_t n, std::uint64_t m ) {
    return n > m ? ::gcd_impl( n, m ) : ::gcd_impl( m, n );
}

また、64 bit整数には A^3が収まるので、因数を3つまとめて保持することにより、高速化を図ります。これをおしすすめると愚直解になるわけです。

ところで、このgcd関数は、g++でコンパイルすると、mov命令を減らしたいのかわかりませんが低速なcmova命令にコンパイルされてしまいます(最大公約数をもっと高速に求める(その2)【cmova命令は遅い】 - よーる)。clang++だとそのようなことはないため、clang++で提出します。

Submission #9639909 - AtCoder Beginner Contest 152

1825msというぎりぎりですが、通すことができました。

かかった時間から推察すると、自作のgcd関数は一回当たり300サイクル程度で動作しているということがわかりました。

*1:素因数ではありません

Karatsuba乗算の最適化をした

Karatsuba乗算、あるいはToom–Cook2乗算とは、入力a0a1b0b1が与えられたとき、a0*b0a0*b1+a1*b0a1*b1の三つを出力する際、必要な乗算を四回から三回に減らせるアルゴリズムです。ただし、減算ができることを要求します。

長乗法(筆算と同様の手順による多倍長整数の乗算のこと)や数列の畳み込み演算は、愚直に行うとΘ(N2)の計算コストがかかります。

Karatsuba乗算のアルゴリズムを使って多倍長整数の乗算を行う場合、桁数が半分の整数の乗算を3回やることになります。 ここで出現した乗算にも再帰的にKaratsuba乗算のアルゴリズムを適用していけば、Ο(N1.585)の計算コストとなります。 ここで、指数の1.585は、正確には \log3/\log2です。

具体的なアルゴリズムは、非常に単純です。

  1. a0*b0a1*b1を計算する。乗算がここで二回必要です。
  2. p = a0-a1q = b0-b1を計算する。減算がここで必要です。
  3. p*qを計算する。乗算がここでもう一回必要です。
  4. a0*b0 + a1*b1 - p*qを計算する。これがa0*b1 + a0*b1になっています。

このように計算結果を再利用することで、乗算の回数を減らすことに成功しています。

以下は、この再帰的にKaratsuba乗算を適用したプログラムを最適化する話です。

問題

Convolution (mod 1,000,000,007)

剰余類環*1 \mathbb{Z}/1000000007\mathbb{Z}における畳み込みを行う問題です。

入力長Nは219であり、最も高速な解法は畳み込み定理と高速フーリエ変換FFT*2を用いた方法でしょう。

しかし、高速フーリエ変換を用いた手法がKaratsuba乗算より高速になるのはN~5000程度とされており、その百倍程度の入力長であればKaratsuba乗算で強引に押し切ることも可能です*3

実装

私が書いたもの

#include <cstdio>
#include <cstdint>
#include <cinttypes>

static constexpr int64_t MOD = 1'000'000'007;
static constexpr size_t N_MAX = 1<<19;
static constexpr size_t K = 8;

template<size_t N>
void karatsuba( const int64_t* x, const int64_t* y, int64_t* z ) {
    static int64_t t[N], p[N/2], q[N/2];

    for( size_t i = 0; i < N/2; ++i ) {
        p[i] = (x[i] - x[i+N/2] + MOD) % MOD;
        q[i] = (y[i] - y[i+N/2] + MOD) % MOD;
    }

    karatsuba<N/2>( x, y, z );
    karatsuba<N/2>( x + N/2, y + N/2, z + N );
    karatsuba<N/2>( p, q, t );

    for( size_t i = 0; i < N/2; ++i ) {
        int64_t a = z[i];
        int64_t b = z[i+N/2];
        int64_t c = z[i+N];
        int64_t d = z[i+N+N/2];
        int64_t e = t[i];
        int64_t f = t[i+N/2];

        z[i+N/2] = a + b + c - e;
        z[i+N]   = d + b + c - f;
    }
}

template<>
void karatsuba<K>( const int64_t* x, const int64_t* y, int64_t* z ) {
    uint64_t tmp[K*2] = {};

    for( size_t i = 0; i < K; ++i ) {
        for( size_t j = 0; j < K; ++j ) {
            tmp[i+j] += x[i] * y[j];
        }
    }

    for( size_t i = 0; i < K*2; ++i ) {
        z[i] = tmp[i] % MOD;
    }
}

int main() {
    static int64_t a[N_MAX] = {}, b[N_MAX] = {}, answer[N_MAX*2];

    size_t n, m;
    scanf( "%zd%zd", &n, &m );

    for( size_t i = 0; i < n; ++i ) {
        scanf( "%" SCNd64, &a[i] );
    }

    for( size_t i = 0; i < m; ++i ) {
        scanf( "%" SCNd64, &b[i] );
    }

    karatsuba<N_MAX>( a, b, answer );
    for( size_t i = 0; i < n+m-1; ++i ) {
        if( i ) { putchar(' '); }
        printf( "%" PRId64, (answer[i]%MOD+MOD)%MOD );
    }
    puts("");

    return 0;
}

工夫その1:静的に領域を確保する

Karatsuba乗算に必要な作業領域は、Nが決まれば静的に決定可能で、動的確保は不要です。

一見再帰呼び出し的なコードになっていますが、常にNが半分ずつになっていく再帰呼び出しです。 よって、テンプレートで記述すれば毎回異なる関数を呼び出す通常の関数呼び出しとなり、関数内static変数として作業領域を確保することができます。

工夫その2:定数回ループ以外の条件分岐を取り除く

条件分岐は遅くなる原因であるので、取り除きます。

0埋めされた計算とはいえ、常に最大ケースと同じだけの計算量が必要になっているので、小さいケースでも時間がかかります。

ジャッジサーバーに負荷をかけてすみません。

工夫その3:小さい部分では再帰しない

Karatsuba乗算は、乗算の回数がオーダーレベルで減る代わりに加減算の回数が増えるという性質があります。

そのため、入力長Nが非常に小さい部分ではΘ(N2)の愚直な方法が高速になります。

定数での剰余演算は、乗算命令二回に最適化されます。これは、a%b == a-a/b*bであること、定数での除算は逆数の乗算に最適化できる*4ことによります。 このことを考慮に入れた、切り替え入力長K以下で愚直計算に切り替えた場合の乗算命令数は以下のようになります。

K 剰余演算に必要な乗算命令数 愚直計算に必要な乗算回数 合計
1 ~319×8 319 104.6G
2 ~318×16 318×4 77.5G
4 ~317×32 317×16 62.0G
8 ~316×64 316×64 55.1G
16 ~315×128 315×256 55.1G
32 ~314×256 314×1024 61.2G

この結果を見ると、乗算回数としてはK=8またはK=16がよさそうということになります。 加減算の数のことまで考えると、N=16の時にKaratsuba法を適用してN=8とするのは乗算回数が減らないのに加減算だけが増える悪手であり、K=16とするのが適切であると結論付けることができます。

しかし、実験の結果、入力長が8のところで方法を切り替えるのが最も高速という結論になりました。

入力長が16のところで切り替えたほうがやや計算量が少なくなるはずですが、コンパイラのインライン展開の限界の影響か、実測では遅くなりました。

切り替え入力長が8であっても、愚直計算のロード命令の数を最小化しようとするとレジスタが20個ほど必要で、x86の持つ汎用レジスタの数16個を超えてしまいます。 切り替え入力長を16にしてしまうとさらにロード命令が増えてしまいます。 これが切り替え入力長16で遅くなる原因と思われます。

工夫その4:なるべく剰余演算を行わない

定数での剰余演算は、乗算命令二回に最適化されますが、それでも重い演算であることには変わりがなく、なるべく行いたくありません。

プログラムの入力は0以上MOD未満ですが、途中で加減算をするとこの制約が満たされなくなります。 そのため、桁あふれを防ぐため、少なくとも乗算の前までには0以上MOD未満の制約を満たすようにしておく必要があります。 karatsuba<K>の最初で剰余をとっても良いですが、もともとの入力は剰余をとらなくてもいいことを考えると、pqを計算した後に行うのが最も回数が少なくなります(回数の差はほとんどはありません)。

乗算結果はMOD*MOD未満であり、これを8個足すくらいでは*5uint64_tは桁あふれしないため、乗算ごとに剰余をとる必要はありません。 karatsuba<K>が終わるときに剰余をとればよさそうです。

z[i+N/2] = a + b + c - e;の部分で四つ足すため、再帰から一段戻るたびに絶対値が4倍程度になりますが、再帰深度はlog(219/8)=16なので、絶対値は232倍にしかなりません。 MODの232倍はint64_tに収まるため、ここでも剰余をとる必要はありません。 最後の出力する部分で剰余をとればよいでしょう。

工夫その5:データのコピー回数を減らす

出力すべき答えのうち前1/4と後1/4は、再帰呼び出しで得られる答えの前1/2や後1/2と正確に一致します。 ここのコピーにかかる時間も実行時間にかなり影響してくるので、再帰関数にポインタを渡して直に書き込んでもらうことでコピーコストを減らします。

ポインタで渡すのはエイリアス解析が困難になりそうですが、最近のコンパイラだと意外とやってくれるようです(?)

更なる最適化

うさぎさんが最適化したもの

p[i] = (x[i] - x[i+N/2] + MOD) % MOD;

としていた部分を

p[i] = x[i] - x[i+N/2];
p[i] += MOD * (p[i] < 0);

と変形することで演算強度を下げる最適化です。x[i]x[i+N/2]0以上MOD未満であることが保証されているので、この変形は無問題です。

p[i] += MOD * (p[i] < 0);の部分は、手元のg++-8g++-9を使った場合、以下のどれを用いてもほとんど同じ速度となりました。

  • p[i] = p[i] < 0 ? p[i] + MOD : p[i];
  • p[i] += p[i] < 0 ? MOD : 0;
  • p[i] <0 && (p[i] += MOD);
  • if( p[i] < 0 ) p[i] += MOD;

いずれもcmovs命令を用いたコードにコンパイルされます。

ジャッジサーバー*6では、if文を用いたものは低速だったようです。テストケースごとの傾向から、これは分岐予測ミスによるものだと考えられます。 ここの分岐は予測不可能な分岐であるため、条件分岐命令にコンパイルされると速度が低下するのでしょう。

ただし、wandboxを使った調査では、g++-6以降やclang++の場合はif文を用いても速度低下はありませんでした。

これを用いた場合、手元のCPU(Intel(R) Core(TM) i7-6700HQ CPU @ 2.60GHz)では(user時間で)2.0秒程度で実行が完了します。 このCPUのターボブースト時クロック周波数は3.50GHzであり、このプログラムに含まれる乗算41.3億回を行うためにはそれだけで最低でも1秒以上かかります*7。 Karatsuba乗算にはこれ以外にも加減算が多く含まれることを考えると、かなり限界に近い性能が出せていることになるでしょう。

*1:環とは、おおざっぱに言って加算・減算・乗算が行えるような数の集合です。

*2:ここでは、複素数環に限らず任意の代数環で行える広義のフーリエ変換を言っています。剰余類環で行うフーリエ変換は、数論変換(NTT)とも呼ばれます。

*3:NlogNとN1.585に5000と219を代入してみると、差はおおよそ10倍程度であるということがわかります

*4:浮動小数点数の除算を逆数乗算に変形した場合はわずかな誤差が発生しますが、整数の場合は正確に求める手法があります

*5:16個でも大丈夫です

*6:g++ 7.4.0を使っているらしいです

*7:乗算命令のスループットは1であり、このCPUでは1秒間には35億回しかできません