10
5

Delete article

Deleted articles cannot be recovered.

The draft of this article will also be deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Rustでソートアルゴリズム (7)バイトニックソート (SIMDで高速化)

Posted at

概要

並列化が容易なため、GPGPU等で有用であるバイトニックソートを実装する。
Wikipediaでの説明(英語)

ナイーブな実装

前提条件

前回と同様にRustのバージョンは1.19.0とします。

% rustc --version
rustc 1.19.0 (0ade33941 2017-07-17)

ソースコード

BitonicSort
/// Sorts `source` in place, in ascending order, with a scalar bitonic sorting network.
///
/// The network is only defined here for power-of-two lengths. For any other
/// length (including 0, because `0usize.next_power_of_two()` is 1) the slice
/// is left untouched and the function returns immediately.
fn sort<T: PartialOrd>(&self, source: &mut [T]) {
    let len = source.len();
    if len.next_power_of_two() != len {
        return;
    }
    // Every pass performs exactly len / 2 compare-exchange operations.
    let pair_count = len >> 1;

    // `block` is the width of the bitonic runs being merged in this phase.
    let mut block = 2;
    while block <= len {
        // `stride` is the distance between the two elements of a pair.
        let mut stride = block >> 1;
        while stride != 0 {
            let low_mask = stride - 1; // bits below the stride bit
            let high_mask = !low_mask; // the stride bit and everything above it

            for pair in 0..pair_count {
                // Open a zero bit at the stride position: `lower` is the first
                // element of the pair, `upper` is the same index with that bit set.
                let lower = ((pair & high_mask) << 1) | (pair & low_mask);
                let upper = lower + stride;

                // Runs whose `block` bit is clear are sorted ascending, the others descending.
                let ascending = (lower & block) == 0;
                let strictly_increasing = source[lower] < source[upper];
                if ascending != strictly_increasing {
                    source.swap(lower, upper);
                }
            }
            stride >>= 1;
        }
        block <<= 1;
    }
}

計算量が$O(n \log^2 n)$のため、あまり速くない。
ただし、このソートは、並列化が容易なため、SIMDで高速化してみる。

SIMDを使う

事前作業

nightlyに変更

% rustup run nightly rustc --version
rustc 1.21.0-nightly (dd53dd5f9 2017-08-01)

ポイント

SIMDを使用するにあたって、simd crateを使うことが一般的だが、今回はplatform-intrinsicを使うこととした。

参考資料

http://qiita.com/tatsuya6502/items/7ffc623fc60be0220409
http://mayah.jp/article/2016/x86intrin/

ソースコード(i32, SSE4.1版)

BitonicSort(SSE4.1)
# ![feature(repr_simd, platform_intrinsics)]

/// Four `i32` lanes declared as a SIMD vector type through `#[repr(simd)]`.
///
/// This is the operand type of the SSE4.1 min/max intrinsics used by the
/// bitonic sort in this listing. The lowercase name mirrors the usual SIMD
/// naming style, hence the `non_camel_case_types` allowance.
# [allow(non_camel_case_types)]
# [repr(simd)]
# [derive(Debug, Copy, Clone)]
struct i32x4(i32, i32, i32, i32);

// Compiler intrinsics reached through the unstable "platform-intrinsic" ABI
// (gated by the `platform_intrinsics` feature enabled at the top of the listing).
// Being foreign items, they can only be called from `unsafe` code.
extern "platform-intrinsic" {
    /// Builds a 4-lane vector by picking lanes according to `idx`.
    /// Judging by the call sites, indices 0-3 address `x` and 4-7 address `y`
    /// — TODO confirm against the compiler's SIMD intrinsic documentation.
    fn simd_shuffle4<T, U>(x: T, y: T, idx: [u32; 4]) -> U;
    /// Presumably the SSE4.1 `_mm_min_epi32` operation (lane-wise signed minimum) — confirm.
    fn x86_mm_min_epi32(x: i32x4, y: i32x4) -> i32x4;
    /// Presumably the SSE4.1 `_mm_max_epi32` operation (lane-wise signed maximum) — confirm.
    fn x86_mm_max_epi32(x: i32x4, y: i32x4) -> i32x4;
}

/// Sorts `source` in place, in ascending order, with a bitonic sorting network.
///
/// Slices of 8 or more elements are processed with 4-lane `i32x4` vectors:
/// the data is copied into a temporary `Vec<i32x4>`, sorted there with
/// min/max/shuffle intrinsics, and copied back. Shorter slices use a scalar
/// version of the same network.
///
/// Only power-of-two lengths are handled. For any other length (including 0,
/// because `0usize.next_power_of_two()` is 1) the function returns without
/// modifying `source`.
///
/// The descriptions below assume the intrinsics are lane-wise signed min/max
/// and a two-vector lane select (indices 0-3 from the first operand, 4-7 from
/// the second).
fn sort(source: &mut [i32]) {
    /// Fully sorts the 8 values held in `value1` (elements 0-3) and `value2`
    /// (elements 4-7). The result reads ascending across `(ret.0, ret.1)` when
    /// `rev` is false and descending when `rev` is true.
    #[inline]
    fn sort8(value1: i32x4, value2: i32x4, rev: bool) -> (i32x4, i32x4) {
        // The intrinsics are foreign functions, so every call must be inside `unsafe`.
        unsafe {
            // Level 1: compare-exchange lane k of `value1` with lane k of `value2`,
            // giving four (min, max) pairs. The shuffle then swaps neighbouring
            // lanes of `max` so that each minimum faces the maximum of the
            // adjacent pair in the next comparison. All shuffle indices are < 4,
            // so only lanes of `max` are selected; the second operand is a filler.
            let max = x86_mm_max_epi32(value1, value2);
            let value1 = x86_mm_min_epi32(value1, value2);
            let value2 = simd_shuffle4(max, value1, [1, 0, 3, 2]);

            // Level 2, step 1: half-cleaner over each group of four values.
            // The shuffles line up the partners for the second step.
            let min = x86_mm_min_epi32(value1, value2);
            let max = x86_mm_max_epi32(value1, value2);
            let value1 = simd_shuffle4(min, max, [0, 4, 7, 3]);
            let value2 = simd_shuffle4(min, max, [1, 5, 6, 2]);

            // Level 2, step 2: finish both groups of four. After the shuffles
            // `value1` holds its group in ascending order and `value2` holds the
            // other group in descending order, so together they are bitonic.
            let min = x86_mm_min_epi32(value1, value2);
            let max = x86_mm_max_epi32(value1, value2);
            let value1 = simd_shuffle4(min, max, [0, 4, 1, 5]);
            let value2 = simd_shuffle4(min, max, [6, 2, 7, 3]);

            // Level 3: merge the bitonic 8-element sequence.
            return sort8l3(value1, value2, rev);
        }
    }

    /// Bitonic merge of 8 values: expects `value1` followed by `value2` to form
    /// a bitonic sequence and returns them fully sorted, ascending when `rev`
    /// is false and descending when `rev` is true.
    #[inline]
    fn sort8l3(value1: i32x4, value2: i32x4, rev: bool) -> (i32x4, i32x4) {
        unsafe {
            // Level 3, distance 4: lane k of `value1` against lane k of `value2`.
            // The shuffles regroup the data so that partners at distance 2 face each other.
            let min = x86_mm_min_epi32(value1, value2);
            let max = x86_mm_max_epi32(value1, value2);
            let value1 = simd_shuffle4(min, max, [0, 1, 4, 5]);
            let value2 = simd_shuffle4(min, max, [2, 3, 6, 7]);

            // Distance 2, then regroup so that neighbours (distance 1) face each other.
            let min = x86_mm_min_epi32(value1, value2);
            let max = x86_mm_max_epi32(value1, value2);
            let value1 = simd_shuffle4(min, max, [0, 4, 2, 6]);
            let value2 = simd_shuffle4(min, max, [1, 5, 3, 7]);

            // Distance 1: `min` now holds the even-ranked and `max` the odd-ranked
            // elements of the sorted order; the final shuffle interleaves them.
            let min = x86_mm_min_epi32(value1, value2);
            let max = x86_mm_max_epi32(value1, value2);
            if rev {
                // Interleave from the top down to emit the values in descending order.
                let value1 = simd_shuffle4(min, max, [7, 3, 6, 2]);
                let value2 = simd_shuffle4(min, max, [5, 1, 4, 0]);
                return (value1, value2);
            } else {
                // Interleave from the bottom up to emit the values in ascending order.
                let value1 = simd_shuffle4(min, max, [0, 4, 1, 5]);
                let value2 = simd_shuffle4(min, max, [2, 6, 3, 7]);
                return (value1, value2);
            }
        }
    }

    // Reject lengths that are not a power of two: `next_power_of_two` returns
    // the length itself only when it already is one.
    let source_size = source.len();
    let size = source_size.next_power_of_two();
    if source_size != size {
        return;
    }
    // Number of compare-exchange pairs per pass of the scalar network.
    let half_size = size >> 1;

    if size >= 8 {
        // Vector path. `size_simd` is the number of 4-lane vectors covering the
        // slice; `half_size_simd` is the number of vector pairs (8-element groups).
        let size_simd = size >> 2;
        let half_size_simd = half_size >> 2;
        let mut work = Vec::<i32x4>::with_capacity(size_simd);
        // Load each group of 8 consecutive elements and sort it. Groups with an
        // odd index are sorted in reverse, so neighbouring groups together form
        // the bitonic sequences the merge phases below expect.
        for k in 0..half_size_simd {
            let l = k << 3;

            let (value1, value2) =
                sort8(
                    i32x4(source[l], source[l + 1], source[l + 2], source[l + 3]),
                    i32x4(source[l + 4], source[l + 5], source[l + 6], source[l + 7]),
                    k & 1 != 0,
                );
            work.push(value1);
            work.push(value2);
        }

        // Merge phases. `i` is the width of the runs being produced, counted in
        // vectors (4 vectors = 16 elements); it doubles until it spans `work`.
        let mut i = 4;
        while i <= size_simd {
            // `j` is the compare distance in vectors. Distances of two or more
            // vectors are handled here with whole-vector min/max.
            let mut j = i >> 1;
            while j > 1 {
                let ml = j - 1; // low-bit mask: bits below the distance bit
                let mh = !ml; // high-bit mask: the distance bit and above

                for k in 0..half_size_simd {
                    // Insert a zero bit at the distance position of `k`: `l` is the
                    // lower vector of the pair and `m` its partner `j` vectors away.
                    let l = ((k & mh) << 1) | (k & ml);
                    let m = l + j;
                    // Runs whose `i` bit is clear are being sorted ascending.
                    let isnot_reverse = (l & i) == 0;

                    unsafe {
                        let max = x86_mm_max_epi32(work[l], work[m]);
                        let min = x86_mm_min_epi32(work[l], work[m]);

                        if isnot_reverse {
                            work[l] = min;
                            work[m] = max;
                        } else {
                            work[l] = max;
                            work[m] = min;
                        }
                    }
                }
                j >>= 1;
            }

            // Remaining distances (one vector, then inside a vector) are covered
            // by the 8-element merge applied to each adjacent vector pair.
            for k in 0..half_size_simd {
                let l = k << 1;
                let m = l + 1;
                let is_reverse = (l & i) != 0;

                let (value1, value2) = sort8l3(work[l], work[m], is_reverse);
                work[l] = value1;
                work[m] = value2;
            }
            i <<= 1;
        }

        // Write the sorted lanes back into the caller's slice.
        for j in 0..work.len() {
            source[(j << 2)] = work[j].0;
            source[(j << 2) + 1] = work[j].1;
            source[(j << 2) + 2] = work[j].2;
            source[(j << 2) + 3] = work[j].3;
        }
    } else {
        // Scalar path for lengths 1, 2 and 4: the same bitonic network, one
        // element per compare-exchange. `i` is the run width, `j` the distance.
        let mut i = 2;
        while i <= size {
            let mut j = i >> 1;
            while j > 0 {
                let ml = j - 1; // low-bit mask: bits below the distance bit
                let mh = !ml; // high-bit mask: the distance bit and above

                for k in 0..half_size {
                    let l = ((k & mh) << 1) | (k & ml);
                    let m = l + j;

                    // Swap unless the pair already follows the run's direction
                    // (ascending when the `i` bit of `l` is clear).
                    if ((l & i) == 0) ^ (source[l] < source[m]) {
                        source.swap(l, m);
                    }
                }
                j >>= 1;
            }
            i <<= 1;
        }
    }
}

ソースコード(i32, AVX2版)

BitonicSort(AVX2)
# ![feature(repr_simd, platform_intrinsics)]

/// Eight `i32` lanes declared as a SIMD vector type through `#[repr(simd)]`.
///
/// This is the operand type of the AVX2 min/max intrinsics used by the
/// bitonic sort in this listing. The lowercase name mirrors the usual SIMD
/// naming style, hence the `non_camel_case_types` allowance.
# [allow(non_camel_case_types)]
# [repr(simd)]
# [derive(Debug, Copy, Clone)]
struct i32x8(i32, i32, i32, i32, i32, i32, i32, i32);

// Compiler intrinsics reached through the unstable "platform-intrinsic" ABI
// (gated by the `platform_intrinsics` feature enabled at the top of the listing).
// Being foreign items, they can only be called from `unsafe` code.
extern "platform-intrinsic" {
    /// Builds an 8-lane vector by picking lanes according to `idx`.
    /// Judging by the call sites, indices 0-7 address `x` and 8-15 address `y`
    /// — TODO confirm against the compiler's SIMD intrinsic documentation.
    fn simd_shuffle8<T, U>(x: T, y: T, idx: [u32; 8]) -> U;
    /// Presumably the AVX2 `_mm256_min_epi32` operation (lane-wise signed minimum) — confirm.
    fn x86_mm256_min_epi32(x: i32x8, y: i32x8) -> i32x8;
    /// Presumably the AVX2 `_mm256_max_epi32` operation (lane-wise signed maximum) — confirm.
    fn x86_mm256_max_epi32(x: i32x8, y: i32x8) -> i32x8;
}

/// Sorts `source` in place, in ascending order, with a bitonic sorting network.
///
/// Slices of 16 or more elements are processed with 8-lane `i32x8` vectors:
/// the data is copied into a temporary `Vec<i32x8>`, sorted there with
/// min/max/shuffle intrinsics, and copied back. Shorter slices use a scalar
/// version of the same network.
///
/// Only power-of-two lengths are handled. For any other length (including 0,
/// because `0usize.next_power_of_two()` is 1) the function returns without
/// modifying `source`.
///
/// The descriptions below assume the intrinsics are lane-wise signed min/max
/// and a two-vector lane select (indices 0-7 from the first operand, 8-15 from
/// the second).
///
/// NOTE(review): the numeric trailing comments inside the helpers look like the
/// 1-based logical position, within the 16-element sequence, that each lane
/// holds at that step — confirm with the author.
fn sort(source: &mut [i32]) {
    /// Bitonic merge of 16 values: expects `value1` followed by `value2` to form
    /// a bitonic sequence and returns them fully sorted, ascending when `rev`
    /// is false and descending when `rev` is true.
    #[inline]
    fn sort16l4(value1: i32x8, value2: i32x8, rev: bool) -> (i32x8, i32x8) {
        // The intrinsics are foreign functions, so every call must be inside `unsafe`.
        unsafe {
            // Level 4, distance 8: lane k of `value1` against lane k of `value2`.
            // The shuffles regroup the data so that partners at distance 4 face each other.
            let min = x86_mm256_min_epi32(value1, value2); // 1,2,3,4,5,6,7,8
            let max = x86_mm256_max_epi32(value1, value2); // 9,10,11,12,13,14,15,16
            let value1 = simd_shuffle8(min, max, [0, 1, 2, 3, 8, 9, 10, 11]); // 1,2,3,4,9,10,11,12
            let value2 = simd_shuffle8(min, max, [4, 5, 6, 7, 12, 13, 14, 15]); // 5,6,7,8,13,14,15,16

            // Distance 4, then regroup for distance 2.
            let min = x86_mm256_min_epi32(value1, value2); // 1,2,3,4,9,10,11,12
            let max = x86_mm256_max_epi32(value1, value2); // 5,6,7,8,13,14,15,16
            let value1 = simd_shuffle8(min, max, [0, 1, 8, 9, 4, 5, 12, 13]); // 1,2,5,6,9,10,13,14
            let value2 = simd_shuffle8(min, max, [2, 3, 10, 11, 6, 7, 14, 15]); // 3,4,7,8,11,12,15,16

            // Distance 2, then regroup for distance 1.
            let min = x86_mm256_min_epi32(value1, value2); // 1,2,5,6,9,10,13,14
            let max = x86_mm256_max_epi32(value1, value2); // 3,4,7,8,11,12,15,16
            let value1 = simd_shuffle8(min, max, [0, 8, 2, 10, 4, 12, 6, 14]); // 1,3,5,7,9,11,13,15
            let value2 = simd_shuffle8(min, max, [1, 9, 3, 11, 5, 13, 7, 15]); // 2,4,6,8,10,12,14,16

            // Distance 1: `min` holds the odd positions and `max` the even
            // positions of the sorted order; the final shuffle interleaves them.
            let min = x86_mm256_min_epi32(value1, value2); // 1,3,5,7,9,11,13,15
            let max = x86_mm256_max_epi32(value1, value2); // 2,4,6,8,10,12,14,16
            if rev {
                // Interleave from the top down to emit the values in descending order.
                let value1 = simd_shuffle8(min, max, [15, 7, 14, 6, 13, 5, 12, 4]);
                let value2 = simd_shuffle8(min, max, [11, 3, 10, 2, 9, 1, 8, 0]);
                return (value1, value2);
            } else {
                // Interleave from the bottom up to emit the values in ascending order.
                let value1 = simd_shuffle8(min, max, [0, 8, 1, 9, 2, 10, 3, 11]);
                let value2 = simd_shuffle8(min, max, [4, 12, 5, 13, 6, 14, 7, 15]);
                return (value1, value2);
            }
        }
    }
    /// Fully sorts the 16 values held in `value1` (elements 0-7) and `value2`
    /// (elements 8-15). Levels 1-3 build a bitonic 16-element sequence, which
    /// `sort16l4` then merges; the result reads ascending across
    /// `(ret.0, ret.1)` when `rev` is false and descending when `rev` is true.
    #[inline]
    fn sort16(value1: i32x8, value2: i32x8, rev: bool) -> (i32x8, i32x8) {
        unsafe {
            // Level 1: compare-exchange lane k of `value1` with lane k of `value2`,
            // giving eight (min, max) pairs. The shuffle swaps neighbouring lanes
            // of `max`; all indices are < 8, so only lanes of `max` are selected
            // and the second operand is a filler.
            let max = x86_mm256_max_epi32(value1, value2); // 2,3,6,7,10,11,14,15
            let value1 = x86_mm256_min_epi32(value1, value2); // 1,4,5,8,9,12,13,16
            let value2 = simd_shuffle8(max, value1, [1, 0, 3, 2, 5, 4, 7, 6]); // 3,2,7,6,11,10,15,14

            // Level 2: two compare-exchange steps that turn the pairs into runs
            // of four, with alternating directions (per the lane annotations).
            let min = x86_mm256_min_epi32(value1, value2); // 1,2,7,8,9,10,15,16
            let max = x86_mm256_max_epi32(value1, value2); // 3,4,5,6,11,12,13,14
            let value1 = simd_shuffle8(min, max, [0, 8, 11, 3, 4, 12, 15, 7]); // 1,3,6,8,9,11,14,16
            let value2 = simd_shuffle8(min, max, [1, 9, 10, 2, 5, 13, 14, 6]); // 2,4,5,7,10,12,13,15

            let min = x86_mm256_min_epi32(value1, value2); // 1,3,6,8,9,11,14,16
            let max = x86_mm256_max_epi32(value1, value2); // 2,4,5,7,10,12,13,15
            let value1 = simd_shuffle8(min, max, [0, 8, 1, 9, 4, 12, 5, 13]); // 1,2,3,4,9,10,11,12
            let value2 = simd_shuffle8(min, max, [10, 2, 11, 3, 14, 6, 15, 7]); // 5,6,7,8,13,14,15,16

            // Level 3: three compare-exchange steps that turn the runs of four
            // into two runs of eight. In the upper four lanes the minimum goes
            // to the later position, i.e. the second run is built descending.
            let min = x86_mm256_min_epi32(value1, value2); // 1,2,3,4,13,14,15,16
            let max = x86_mm256_max_epi32(value1, value2); // 5,6,7,8,9,10,11,12
            let value1 = simd_shuffle8(min, max, [0, 1, 8, 9, 4, 5, 12, 13]); // 1,2,5,6,13,14,9,10
            let value2 = simd_shuffle8(min, max, [2, 3, 10, 11, 6, 7, 14, 15]); // 3,4,7,8,15,16,11,12

            let min = x86_mm256_min_epi32(value1, value2); // 1,2,5,6,15,16,11,12
            let max = x86_mm256_max_epi32(value1, value2); // 3,4,7,8,13,14,9,10
            let value1 = simd_shuffle8(min, max, [0, 8, 2, 10, 14, 6, 12, 4]); // 1,3,5,7,9,11,13,15
            let value2 = simd_shuffle8(min, max, [1, 9, 3, 11, 15, 7, 13, 5]); // 2,4,6,8,10,12,14,16

            let min = x86_mm256_min_epi32(value1, value2); // 1,3,5,7,10,12,14,16
            let max = x86_mm256_max_epi32(value1, value2); // 2,4,6,8,9,11,13,15
            let value1 = simd_shuffle8(min, max, [0, 8, 1, 9, 2, 10, 3, 11]); // 1,2,3,4,5,6,7,8
            let value2 = simd_shuffle8(min, max, [12, 4, 13, 5, 14, 6, 15, 7]); // 9,10,11,12,13,14,15,16

            // Level 4: merge the bitonic 16-element sequence.
            return sort16l4(value1, value2, rev);
        }
    }

    // Reject lengths that are not a power of two: `next_power_of_two` returns
    // the length itself only when it already is one.
    let source_size = source.len();
    let size = source_size.next_power_of_two();
    if source_size != size {
        return;
    }
    // Number of compare-exchange pairs per pass of the scalar network.
    let half_size = size >> 1;

    if size >= 16 {
        // Vector path. `size_simd` is the number of 8-lane vectors covering the
        // slice; `half_size_simd` is the number of vector pairs (16-element groups).
        let size_simd = size >> 3;
        let half_size_simd = half_size >> 3;
        let mut work = Vec::<i32x8>::with_capacity(size_simd);
        // Load each group of 16 consecutive elements and sort it. Groups with an
        // odd index are sorted in reverse, so neighbouring groups together form
        // the bitonic sequences the merge phases below expect.
        for k in 0..half_size_simd {
            let l = k << 4;

            let (value1, value2) = sort16(
                i32x8(
                    source[l],
                    source[l + 1],
                    source[l + 2],
                    source[l + 3],
                    source[l + 4],
                    source[l + 5],
                    source[l + 6],
                    source[l + 7],
                ),
                i32x8(
                    source[l + 8],
                    source[l + 9],
                    source[l + 10],
                    source[l + 11],
                    source[l + 12],
                    source[l + 13],
                    source[l + 14],
                    source[l + 15],
                ),
                k & 1 != 0,
            );
            work.push(value1);
            work.push(value2);
        }

        // Merge phases. `i` is the width of the runs being produced, counted in
        // vectors (4 vectors = 32 elements); it doubles until it spans `work`.
        let mut i = 4;
        while i <= size_simd {
            // `j` is the compare distance in vectors. Distances of two or more
            // vectors are handled here with whole-vector min/max.
            let mut j = i >> 1;
            while j > 1 {
                let ml = j - 1; // low-bit mask: bits below the distance bit
                let mh = !ml; // high-bit mask: the distance bit and above

                for k in 0..half_size_simd {
                    // Insert a zero bit at the distance position of `k`: `l` is the
                    // lower vector of the pair and `m` its partner `j` vectors away.
                    let l = ((k & mh) << 1) | (k & ml);
                    let m = l + j;
                    // Runs whose `i` bit is clear are being sorted ascending.
                    let isnot_reverse = (l & i) == 0;

                    unsafe {
                        let max = x86_mm256_max_epi32(work[l], work[m]);
                        let min = x86_mm256_min_epi32(work[l], work[m]);

                        if isnot_reverse {
                            work[l] = min;
                            work[m] = max;
                        } else {
                            work[l] = max;
                            work[m] = min;
                        }
                    }
                }
                j >>= 1;
            }

            // Remaining distances (one vector, then inside a vector) are covered
            // by the 16-element merge applied to each adjacent vector pair.
            for k in 0..half_size_simd {
                let l = k << 1;
                let m = l + 1;
                let is_reverse = (l & i) != 0;

                let (value1, value2) = sort16l4(work[l], work[m], is_reverse);
                work[l] = value1;
                work[m] = value2;
            }
            i <<= 1;
        }

        // Write the sorted lanes back into the caller's slice.
        for j in 0..work.len() {
            source[(j << 3)] = work[j].0;
            source[(j << 3) + 1] = work[j].1;
            source[(j << 3) + 2] = work[j].2;
            source[(j << 3) + 3] = work[j].3;
            source[(j << 3) + 4] = work[j].4;
            source[(j << 3) + 5] = work[j].5;
            source[(j << 3) + 6] = work[j].6;
            source[(j << 3) + 7] = work[j].7;
        }
    } else {
        // Scalar path for lengths 1, 2, 4 and 8: the same bitonic network, one
        // element per compare-exchange. `i` is the run width, `j` the distance.
        let mut i = 2;
        while i <= size {
            let mut j = i >> 1;
            while j > 0 {
                let ml = j - 1; // low-bit mask: bits below the distance bit
                let mh = !ml; // high-bit mask: the distance bit and above

                for k in 0..half_size {
                    let l = ((k & mh) << 1) | (k & ml);
                    let m = l + j;

                    // Swap unless the pair already follows the run's direction
                    // (ascending when the `i` bit of `l` is clear).
                    if ((l & i) == 0) ^ (source[l] < source[m]) {
                        source.swap(l, m);
                    }
                }
                j >>= 1;
            }
            i <<= 1;
        }
    }
}

かなり高速化した。要素数次第ではあるが、標準ソートより速い。

ベンチマーク

Rustでソートアルゴリズム (0)まとめ

10
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
5

Delete article

Deleted articles cannot be recovered.

The draft of this article will also be deleted.

Are you sure you want to delete this article?