C++
sort
OpenMP
Grand_Central_Dispatch

マージソートをマルチスレッドで爆速にする件

動機

昔からマージソートを自前で書いてマルチスレッド化したら std::sort より早くなる可能性があるんじゃね?って思ってたので、やってみました。
ソースは最後にのせてあります。

マージソートとは

マージソートにもいくつか種類があるのですが、この記事では次のようなアルゴリズムを想定しています。

入力が

84736251

だったとします。2つずつ組みにして小さい順にします。

48 37 26 15

2 つずつの組みをマージしていきます。

3478 1256

同様に全体が一つになるまでマージしていきます。

12345678

できあがり、みたいなやつです。

長所

  • アルゴリズムが非常に簡単
  • 時間計算量が O(NlogN)(std::sortとほぼ一緒)
  • 再帰なしで書けるので安全。

欠点

  • 入力と同じ大きさの作業バッファを必要とする。 (作業バッファを使わないインプリも可能なんですが、私の腕では速度的に満足するものが書けませんでした)

結果

以下の計測の入力データは、要素数をNとしたら、0 から N-1 までの数をシャッフルしたデータです。

環境1

とりあえず手元の慣れてる環境でやってみました。

  • MacPro Late2013 3.7 GHz Quad-Core Intel Xeon E5 (4コア)
  • clang++ -Os でコンパイル
  • マルチスレッドは Grand Central Dispatcher (GCD) で実現
アルゴリズム 64M要素 128M要素 256M要素
std::sort 4,999ms 10,431ms 21,631ms
MergeSort(シングル) 5,005ms 10,308ms 21,483ms
MergeSort(マルチ) 3,072ms 6,107ms 12,588ms

シングルスレッドのスピード、MergeSort と std::sort でほぼ一緒です。
GCD は 1024 要素がソート済みになったところから使っています。(ソース参照)
倍速とまではいかないですが、結構早くなってます。

環境2

なんか調子よさそうなので、36 vCPU の AWS の EC2 インスタンスで OpenMP やってみました。

  • AWS c4.8xlarge インスタンス(36 vCPU)
  • シングルスレッドは g++ -Os, マルチスレッドは g++ -Os -fopenmp でコンパイル
  • マルチスレッドは OpenMP で実現
アルゴリズム 64M要素 128M要素 256M要素
std::sort 6,100ms 12,577ms 26,118ms
MergeSort(シングル) 9,258ms 19,111ms 39,548ms
MergeSort(マルチ) 1,515ms 2,948ms 6,082ms

MergeSort のマルチ、std::sort の4倍速です。なんかちょっと嬉しいですw

でもって、OpenMP、めっちゃ簡単です。事実上

#pragma omp parallel for

って書いただけです。もうちょっといろいろやれる可能性もありそうです。

ちなみに、ec2 の費用、0.329 時間で $0.66 でした。

ソース

#include    <iostream>
#include    <algorithm>
#include    <vector>
#include    <map>
#include    <chrono>
#include    <random>

using       namespace std;
using       namespace chrono;

////////////////////////////////////////////////////////////////

template<typename T>    vector< T >
Data( T p ) {
    vector< T > v( p );
    for ( auto i = 0; i < p; i++ ) v[ i ] = p - i - 1;
    return v;
}
template<typename T>    vector< T >
ShuffledData( T p ) {

    mt19937 wMT( (random_device())() );

    auto v = Data( p );
    shuffle( v.begin(), v.end(), wMT );
    return v;
}
template<typename T>    vector< T >
RandomData( T p ) {
    mt19937 wMT( (random_device())() );
    uniform_int_distribution< T >  wUID( 0, p );

    vector< T > v( p );
    for ( auto i = 0; i < p; i++ ) v[ i ] = wUID( wMT );
    return v;
}

template<typename T>    void
Check( const vector< T >& p ) {
    for ( auto i = 0; i < p.size(); i++ ) if ( p[ i ] != i ) {
        cerr << "CHECK ERROR: " << i << ':' << p[ i ] << endl;
        throw 1;
    }
}

struct
StopWatch {
    string                      tag;
    time_point< system_clock >  m;

    StopWatch( string pTag ) {
        tag = pTag;
        m = system_clock::now();
    }
    ~
    StopWatch() {
        auto w = duration_cast< milliseconds >( system_clock::now() - m ).count();
        cerr << tag << ':' << w << endl;
    }
};


////////////////////////////////////////////////////////////////    Switch Buffer
template< typename T >  void
MergeOutplace( T* l, T* r, T* e, T* d ) {
    auto    wL = l;
    auto    wR = r;
    do {
        if ( wL == r ) {
            while ( wR < e ) *d++ = *wR++;
            break;
        }
        if ( wR == e ) {
            while ( wL < r ) *d++ = *wL++;
            break;
        }
        *d++ = *wL < *wR ? *wL++ : *wR++;
    } while ( true );
}

#undef  USE_GCD
#ifdef  USE_GCD
#include "dispatch/dispatch.h"
#endif

template< typename T >  void
MergeSort( T* l, T* r ) {

#ifdef  USE_GCD
cerr << "Using GCD" << endl;
#else
#ifdef _OPENMP
    cerr << "Using OpenMP" << endl;
#else
    cerr << "No parallelism" << endl;
#endif
#endif
    auto wSize = r - l;
    if ( wSize == 0 ) return;
    if ( wSize == 1 ) return;
    for ( auto i = 0; i < wSize / 2 * 2; i += 2 ) if ( l[ i + 1 ] < l[ i ] ) swap( l[ i ], l[ i + 1 ] );
    if ( wSize == 2 ) return;

    auto wSorted = 2;
    auto wInW = false;
    vector< T > w( wSize );

#ifdef  USE_GCD
    auto    q = dispatch_get_global_queue( DISPATCH_QUEUE_PRIORITY_DEFAULT, 0 );
    auto    g = dispatch_group_create();
#endif
    do {
        auto wToBe = wSorted * 2;
        auto wBorder = wSize / wToBe * wToBe;

        T* wS = wInW ? &w[ 0 ] : l;
        T* wD = wInW ? l : &w[ 0 ];
        wInW = !wInW;

#ifdef  USE_GCD
        if ( wSorted < 1024 ) {
            for ( auto i = 0; i < wBorder; i += wToBe ) {
                auto w = wS + i;
                MergeOutplace( w, w + wSorted, w + wToBe, wD + i );
            }
            if ( wSize - wBorder > wSorted ) {
                auto w = wS + wBorder;
                MergeOutplace( w, w + wSorted, r, wD + wBorder );
            } else {
                for ( auto i = wBorder; i < wSize; i++ ) wD[ i ] = wS[ i ];
            }
        } else {
            for ( auto i = 0; i < wBorder; i += wToBe ) {
                dispatch_group_async( g, q, ^{
                    auto w = wS + i;
                    MergeOutplace( w, w + wSorted, w + wToBe, wD + i );
                } );
            }
            if ( wSize - wBorder > wSorted ) {
                dispatch_group_async( g, q, ^{
                    auto w = wS + wBorder;
                    MergeOutplace( w, w + wSorted, r, wD + wBorder );
                } );
            } else {
                for ( auto i = wBorder; i < wSize; i++ ) wD[ i ] = wS[ i ];
            }
            dispatch_group_wait( g, DISPATCH_TIME_FOREVER );
        }
#else
#ifdef _OPENMP
    #pragma omp parallel for
#endif
        for ( auto i = 0; i < wBorder; i += wToBe ) {
            auto    w = wS + i;
            MergeOutplace( w, w + wSorted, w + wToBe, wD + i );
        }
        if ( wSize - wBorder > wSorted ) {
            auto    w = wS + wBorder;
            MergeOutplace( w, w + wSorted, r, wD + wBorder );
        } else {
            for ( auto i = wBorder; i < wSize; i++ ) wD[ i ] = wS[ i ];
        }
#endif
        wSorted *= 2;
    } while ( wSorted < wSize );
    if ( wInW ) for ( auto i = 0; i < wSize; i++ ) l[ i ] = w[ i ];
}

int
main() {
    for ( unsigned wSize = 256 * 256; wSize <= 127 * 256 * 256 * 256; wSize *= 2 ) {
        auto    wTag = to_string( wSize );
        cerr << wTag << endl;
        {   auto wShuffledData = ShuffledData( wSize );
            cerr << "Start" << endl;
            {   auto wData = wShuffledData;
                {   StopWatch   wSW( wTag + " ShuffledData MergeSort" );
                    MergeSort( &wData[ 0 ], &wData[ wData.size() ] );
                }
                Check( wData );
            }
            {   auto wData = wShuffledData;
                {   StopWatch   wSW( wTag + " ShuffledData std::sort" );
                    sort( &wData[ 0 ], &wData[ wData.size() ] );
                }
                Check( wData );
            }
        }
        {   auto wRandomData = RandomData( wSize );
            cerr << "Start" << endl;
            auto wData1 = wRandomData;
            {   StopWatch   wSW( wTag + " RandomData MergeSort" );
                MergeSort( &wData1[ 0 ], &wData1[ wData1.size() ] );
            }
            auto wData2 = wRandomData;
            {   StopWatch   wSW( wTag + " RandomData std::sort" );
                sort( &wData2[ 0 ], &wData2[ wData2.size() ] );
            }
            for ( auto i = 0; i < wRandomData.size(); i++ ) if ( wData1[ i ] != wData2[ i ] ) throw 0;
        }
    }
    return 0;
}