bit列のパターンマッチング分解を行う方法

先週の話題

bit列のマッチングを行う方法 - よーる

に引き続き、bit列操作の可読性を上げようという試みです。

パターンマッチを行い、特定のパターンに合致したと判明した時、残りの部分を(そのパターン特有の方法で)分解することは、特に関数型言語においては頻出の操作です。

bit列操作において頻出の操作であるかについてはやや疑問が残るところですが、使いたくなったので作ってみました。

まず、指定されたbitを抽出する、extract_bits関数を作ります。

#include <cstdint>
#include <utility>

constexpr uint64_t make_mask( const char* str, char C, uint64_t acc = 0 ) {
    return *str ? make_mask( str+1, C, acc<<1 | *str==C ) : acc;
}

template<uint64_t Mask, std::size_t... Indeces>
constexpr uint64_t extract_bits_impl( uint64_t bits, std::index_sequence<Indeces...> ) {
    uint64_t result = 0;
    int i = 0;

    ( (result |= Mask&1ull<<Indeces ? !!(bits&1ull<<Indeces)<<i++ : 0), ... );
    return result;
}

template<char C, class Str>
constexpr uint64_t extract_bits( Str pattern, uint64_t bits ) {
    return extract_bits_impl<make_mask( pattern(), C )>( bits, std::make_index_sequence<64>() );
}

#include <iostream>
int main() {
    uint64_t x;
    std::cin >> x;
    std::cout << extract_bits<'a'>( []{return"aaaabbbb";}, x ) << std::endl;
    std::cout << extract_bits<'b'>( []{return"aaaabbbb";}, x ) << std::endl;
    std::cout << extract_bits<'a'>( []{return"aaaaaaaabbbbbbbb";}, x ) << std::endl;
    std::cout << extract_bits<'b'>( []{return"aaaabbbbaaaabbbb";}, x ) << std::endl;
}

このコードは、実行時の効率にも注意を払っています。たとえば、for文で64回ループを作るのではなく、コンマ演算子の畳み込みを用いることで、ループアンローリングを強制します。 このようにすることで、コンパイラの定数伝搬最適化を助けることができます。

実際、このコードをclang++ -std=c++1z -O2 -S -masm=intel(バージョンは4.0.1、古い……)でコンパイルすると、以下のようなコードになります。

# extract_bits<'a'>( []{return"aaaabbbb";}, x )
mov   edx, dword ptr [rbp - 8]
shr   edx, 4
and   edx, 15

# extract_bits<'b'>( []{return"aaaabbbb";}, x )
mov   rdx, qword ptr [rbp - 8]
and   edx, 15

# extract_bits<'a'>( []{return"aaaaaaaabbbbbbbb";}, x )
mov   eax, dword ptr [rbp - 8]
movzx edx, ah  # NOREX

# extract_bits<'b'>( []{return"aaaabbbbaaaabbbb";}, x )
mov   rdx, qword ptr [rbp - 8]
mov   eax, edx
and   eax, 15
shr   rdx, 4
mov   ecx, edx
and   ecx, 16
or    rcx, rax
mov   eax, edx
and   eax, 32
or    rax, rcx
mov   ecx, edx
and   ecx, 64
or    rcx, rax
and   edx, 128
or    rdx, rcx

上三つのコードは最適コードになっています。最近のコンパイラはさすがの最適化性能です。最後のものは上位4bitの計算部分が最適とは言えなくなっていますが、抽出すべきbit部分が分離している時点で相当難しいことを要求しているため、これくらいは仕方がないでしょう(後半のor命令では不要に64bitレジスタを使っているところからもコンパイラが"見抜け"ていないことがうかがえます)。

パターンマッチを行うところまで実装すると、以下のようになります。

#include <cstdint>
#include <utility>
#include <array>

constexpr uint64_t make_mask( const char* str, char C, uint64_t acc = 0 ) {
    return *str ? make_mask( str+1, C, acc<<1 | *str==C ) : acc;
}

template<uint64_t Mask, std::size_t... Indeces>
constexpr uint64_t extract_bits_impl( uint64_t bits, std::index_sequence<Indeces...> ) {
    uint64_t result = 0;
    int i = 0;

    ( (result |= Mask&1ull<<Indeces ? !!(bits&1ull<<Indeces)<<i++ : 0), ... );
    return result;
}

template<char C, class Str>
constexpr uint64_t extract_bits( Str pattern, uint64_t bits ) {
    return extract_bits_impl<make_mask( pattern(), C )>( bits, std::make_index_sequence<64>() );
}



constexpr uint64_t make_mask( const char* str, uint64_t acc = 0 ) {
    return *str ? make_mask( str+1, acc<<1 | *str=='0' | *str=='1' ) : acc;
}

constexpr uint64_t make_bits( const char* str, uint64_t acc = 0 ) {
    return *str ? make_bits( str+1, acc<<1 | *str=='1' ) : acc;
}

template<class Str>
constexpr bool match( Str pattern, uint64_t test_bits ) {
    constexpr uint64_t mask = make_mask( pattern() );
    constexpr uint64_t bits = make_bits( pattern() );
    
    return (test_bits&mask) == bits;
}

template<char... Chars>
class Match {
    bool success;
    std::array<uint64_t, sizeof...(Chars)> decomp;
public:
    template<class Str>
    constexpr Match( Str pattern, uint64_t test_bits )
        : success( match( pattern, test_bits ) )
        , decomp { extract_bits<Chars>( pattern, test_bits )... }
        {}
        

    template<std::size_t N>
    constexpr auto get() const {
        if constexpr( N == 0 ) { return success; }
        else { return std::get<N-1>( decomp ); }
    }    
};

namespace std {
    template<char... Chars>
    class tuple_size<Match<Chars...>> : public std::integral_constant<std::size_t, sizeof...(Chars) + 1> {};
    
    template<std::size_t N, char... Chars>
    class tuple_element<N, Match<Chars...>> {
    public:
        using type = decltype( std::declval<Match<Chars...>>().template get<N>() );
    };
}

#include <iostream>
int main() {
    const uint64_t x = 0b00000000101000000000010110010011;
    auto [ADDI, rd, rs1, imm] = Match<'d','a','i'>( []{return"iiiiiiiiiiiiaaaaa000ddddd0010011";}, x );
    if( ADDI ) {
        std::cout << "ADDi x" << rd << ", x" << rs1 << ", " << imm << std::endl;
    }
}

[Wandbox]三へ( へ՞ਊ ՞)へ ハッハッ

ビット列のパターンマッチング分解なんて何に使うのかと思った方もいらっしゃったかもしれませんが、このように機械語のビット列をパースするのに使えます。 main関数では、RISC-VというISAのADDi x11, x0, 10という命令の機械語をパースしています(本来は符号拡張を実装しないといけないですが……)。