LoginSignup
5
3

More than 3 years have passed since last update.

ソーティングネットワークをコンパイル時に生成してみた

Last updated at Posted at 2019-05-29

ソーティングネットワークについて

wikipediaをどうぞ

動機

自分は組み込み系に触れていて、最近数値データをソートする需要1に出会った。データ数は非常に少なく(N<50)、固定である。

そう、ソーティングネットワークの出番だ。Qiitaで「std::sortより特定条件で高速な8入力ソーティングネットワークを作ってみる」を見つけたので、こちらを参考に組んでいく。

ただ、設定で要素数を変更することがあるのでコンパイル時に生成できるようにする。

実装

ソーティングネットワークの生成には、元記事で紹介されていたソルバが使っていたBose-Nelsonアルゴリズムを使用した。ただ、そのアルゴリズムの詳細がわからなかったのでこちらを見ながらコンパイル時に生成できるように変更。Compare-Swapに関しては元記事で掘り下げられていたのでそちらをパクリスペクトしつつ算術型以外では普通のSWAPになるようにした。

#include <tuple>
#include <type_traits>
#include <iterator>

namespace srtnw
{

template <std::size_t i_, std::size_t j_>
struct SwapPair
{
    static constexpr std::size_t i = i_, j = j_;
};

template <class RandomAccessIterator>
inline constexpr void compair_swap(RandomAccessIterator v, std::size_t i, std::size_t j)
{
    if constexpr (std::is_arithmetic_v<typename std::iterator_traits<decltype(v)>::value_type>) {
        auto v_i = std::move(v[i]);
        auto v_j = std::move(v[j]);
        v[i] = v_i < v_j ? v_i : v_j;
        v[j] = v_i < v_j ? v_j : v_i;
    } else {
        using std::swap;
        if (v[j] < v[i]) swap(v[i], v[j]);
    }
}

template <class SNW>
class SortingNetwork
{
    template <class IT, class... Args>
    inline static constexpr void sort_impl(IT first, std::tuple<Args...>*)
    {
        (..., compair_swap(first, Args::i, Args::j));
    }

public:
    template <class RandomAccessIterator>
    inline static constexpr void sort(RandomAccessIterator first) { sort_impl(first, static_cast<typename SNW::type*>(nullptr)); }
};

template <std::size_t n>
struct BoseNelson
{
private:
    template <class... Args>
    using unwrap_tuple_cat_t = decltype(std::tuple_cat(std::declval<typename Args::type>()...));

    template <std::size_t i, /* value of first element in sequence 1 */
              std::size_t x, /* length of sequence 1 */
              std::size_t j, /* value of first element in sequence 2 */
              std::size_t y> /* length of sequence 2 */
    struct Pbracket
    {
        static constexpr std::size_t a = x / 2;
        static constexpr std::size_t b = (x & 1) ? (y / 2) : ((y + 1) / 2);

        using type = unwrap_tuple_cat_t<Pbracket<i, a, j, b>,
                                        Pbracket<i + a, x - a, j + b, y - b>,
                                        Pbracket<i + a, x - a, j, b>>;
    };
    template <std::size_t i, std::size_t j>
    struct Pbracket<i, 1, j, 1>
    {
        using type = std::tuple<SwapPair<i, j>>;
    };
    template <std::size_t i, std::size_t j>
    struct Pbracket<i, 1, j, 2>
    {
        using type = std::tuple<SwapPair<i, j + 1>, SwapPair<i, j>>;
    };
    template <std::size_t i, std::size_t j>
    struct Pbracket<i, 2, j, 1>
    {
        using type = std::tuple<SwapPair<i, j>, SwapPair<i + 1, j>>;
    };

    template <std::size_t i, std::size_t m, class /* for detection idiom */ = void>
    struct Pstar
    {
        static constexpr std::size_t a = m / 2;
        using type = unwrap_tuple_cat_t<Pstar<i, a>,
                                        Pstar<i + a, m - a>,
                                        Pbracket<i, a, i + a, m - a>>;
    };
    template <std::size_t i, std::size_t m>
    struct Pstar<i, m, std::enable_if_t<(m < 2)>>
    {
        using type = std::tuple<>;
    };

public:
    using type = typename Pstar<0, n>::type;
};
template <std::size_t n>
using BoseNelson_t = typename BoseNelson<n>::type;

static_assert(std::is_same_v<typename BoseNelson<0>::type,
                             std::tuple<>>);
static_assert(std::is_same_v<typename BoseNelson<1>::type,
                             std::tuple<>>);
static_assert(std::is_same_v<typename BoseNelson<2>::type,
                             std::tuple<SwapPair<0, 1>>>);
static_assert(std::is_same_v<typename BoseNelson<8>::type,
                             std::tuple<SwapPair<0, 1>,
                                        SwapPair<2, 3>,
                                        SwapPair<0, 2>,
                                        SwapPair<1, 3>,
                                        SwapPair<1, 2>,
                                        SwapPair<4, 5>,
                                        SwapPair<6, 7>,
                                        SwapPair<4, 6>,
                                        SwapPair<5, 7>,
                                        SwapPair<5, 6>,
                                        SwapPair<0, 4>,
                                        SwapPair<1, 5>,
                                        SwapPair<1, 4>,
                                        SwapPair<2, 6>,
                                        SwapPair<3, 7>,
                                        SwapPair<3, 6>,
                                        SwapPair<2, 4>,
                                        SwapPair<3, 5>,
                                        SwapPair<3, 4>>>);
} // namespace srtnw

本当にソーティングネットワークが実装できているかをチェックするために性能評価をする。
比較方法は元記事と同じで
比較するのは今回作ったソート、元記事のソート、std::sort
ソート本体を最適化で消されないようにstd::is_sortedを挟んだ

#include <iostream>
#include <algorithm>
#include <random>
#include <vector>
#include <cstdint>

void network_sort_8n(int32_t* v)
{
#define SWAP(a, b) srtnw::compair_swap(v, a, b)
    SWAP(0, 1);
    SWAP(2, 3);
    SWAP(4, 5);
    SWAP(6, 7);
    SWAP(0, 2);
    SWAP(1, 3);
    SWAP(4, 6);
    SWAP(5, 7);
    SWAP(1, 2);
    SWAP(5, 6);
    SWAP(0, 4);
    SWAP(3, 7);
    SWAP(1, 5);
    SWAP(2, 6);
    SWAP(1, 4);
    SWAP(3, 6);
    SWAP(2, 4);
    SWAP(3, 5);
    SWAP(3, 4);
#undef SWAP
}

inline uint64_t rdtsc()
{
    uint32_t tickl, tickh;
    __asm__ __volatile__("rdtsc"
                         : "=a"(tickl), "=d"(tickh));
    return (static_cast<uint64_t>(tickh) << 32) | tickl;
}

void benchmark()
{
    std::mt19937 engine(std::random_device{}());
    std::uniform_int_distribution<int> dist;

    constexpr const int N = 80 * 1000 * 1000;
    std::vector<int32_t> v(N);
    for (int i = 0; i < N; i++) {
        v[i] = dist(engine);
    }
    auto t0 = rdtsc();
    for (int n = 0; n < N; n += 8) {
        srtnw::SortingNetwork<srtnw::BoseNelson<8>>::sort(v.data() + n);
        // network_sort_8n(v.data() + n);
        // std::sort(v.begin() + n, v.begin() + n + 8);
    }
    auto t1 = rdtsc();
    std::cout << (t1 - t0) / 1000 / 1000 << "[Mcycles]" << std::endl;
    std::cout << std::boolalpha << std::is_sorted(v.begin(), v.end()) << std::endl;
}

int main()
{
    benchmark();
}

コンパイラはclang7 結果はそれぞれ3回平均

項目 最適化 結果(Mcycles)
BoseNelson O0 9224
BoseNelson O3 239
network_sort_8n O0 9170
network_sort_8n O3 229
std::sort O0 16940
std::sort O3 2194

僅かにnetwork_sort_8nのほうが早いけどまあ十分な性能だと思う。

どうでもいい余談

今までマークダウンでのリンクでカッコの順番をしょっちゅう間違えていたのだけれど、[テキスト](アドレス)の並びがC++でのラムダ式のカッコの順と同じと気付いてから間違えなくなった。


  1. 刈り込み平均、メディアンフィルタ 

5
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
3