アルゴリズム
Rust
SIMD
ソート

Rustでソートアルゴリズム (7)バイトニックソート (SIMDで高速化)

More than 1 year has passed since last update.

概要

並列化が容易なため、GPGPU等で有用であるバイトニックソートを実装する。
Wikipediaでの説明(英語)

ナイーブな実装

前提条件

前回と同様にRustのバージョンは1.19.0とします。

% rustc --version
rustc 1.19.0 (0ade33941 2017-07-17)

ソースコード

BitonicSort
/// Sorts `source` in place with a bitonic sorting network, smallest element
/// first (for elements that are totally ordered under `<`).
///
/// The network is only defined for power-of-two lengths: for any other length
/// (including an empty slice) the function returns without modifying `source`.
fn sort<T: PartialOrd>(&self, source: &mut [T]) {
    let len = source.len();
    // `next_power_of_two` yields `len` itself only when `len` already is one
    // (and yields 1 for 0, so the empty slice is rejected here as well).
    if len.next_power_of_two() != len {
        return;
    }
    // Every pass performs exactly one compare-exchange per pair of elements.
    let pair_count = len / 2;

    // `span` is the length of the sorted runs produced by the current phase.
    let mut span = 2;
    while span <= len {
        // `stride` is the distance between the two elements of each pair.
        let mut stride = span / 2;
        while stride != 0 {
            let low_mask = stride - 1;
            let high_mask = !low_mask;

            for pair in 0..pair_count {
                // Open a zero bit at the `stride` position of the pair number
                // to obtain the lower index; the partner has that bit set.
                let lo = ((pair & high_mask) << 1) | (pair & low_mask);
                let hi = lo + stride;

                // Runs whose `span` bit is clear are built ascending, the
                // others descending.
                let ascending = lo & span == 0;
                let in_order = source[lo] < source[hi];
                if ascending != in_order {
                    source.swap(lo, hi);
                }
            }
            stride /= 2;
        }
        span *= 2;
    }
}

$O(n \log^2 n)$のため、あまり速くない。
ただし、このソートは、並列化が容易なため、SIMDで高速化してみる。

SIMDを使う

事前作業

nightlyに変更

% rustup run nightly rustc --version
rustc 1.21.0-nightly (dd53dd5f9 2017-08-01)

ポイント

SIMDを使用するにあたって、simd crateを使うことが一般的だが、今回はplatform-intrinsicを使うこととした。

参考資料

http://qiita.com/tatsuya6502/items/7ffc623fc60be0220409
http://mayah.jp/article/2016/x86intrin/

ソースコード(i32, SSE4.1版)

BitonicSort(SSE4.1)
#![feature(repr_simd, platform_intrinsics)]

/// Vector of four `i32` lanes.
///
/// Declared `#[repr(simd)]` (enabled by the crate-level `repr_simd` feature)
/// and used as the operand type of the min/max intrinsics in this file.
/// The lower-case name needs the `non_camel_case_types` allowance.
#[allow(non_camel_case_types)]
#[repr(simd)]
#[derive(Debug, Copy, Clone)]
struct i32x4(i32, i32, i32, i32);

// Intrinsics declared through the unstable "platform-intrinsic" ABI
// (enabled by the crate-level `platform_intrinsics` feature).
extern "platform-intrinsic" {
    /// Builds a 4-lane vector whose lanes are selected by `idx` from `x` and `y`.
    /// NOTE(review): the callers in this file rely on indices 0..=3 addressing
    /// `x` and 4..=7 addressing `y` — confirm against the intrinsic's definition.
    fn simd_shuffle4<T, U>(x: T, y: T, idx: [u32; 4]) -> U;
    /// Lane-wise minimum of two `i32x4` values
    /// (by its name, the SSE4.1 `_mm_min_epi32` operation — TODO confirm).
    fn x86_mm_min_epi32(x: i32x4, y: i32x4) -> i32x4;
    /// Lane-wise maximum of two `i32x4` values
    /// (by its name, the SSE4.1 `_mm_max_epi32` operation — TODO confirm).
    fn x86_mm_max_epi32(x: i32x4, y: i32x4) -> i32x4;
}

/// Sorts `source` in place in ascending order with a bitonic sorting network.
///
/// Only power-of-two lengths are handled; for any other length (including an
/// empty slice) the function returns without modifying `source`. Slices of 8
/// or more elements are processed as `i32x4` vectors with min/max/shuffle
/// intrinsics; shorter slices use a scalar compare-and-swap network.
///
/// NOTE(review): nothing in this function checks that the CPU supports the
/// intrinsics it calls — presumably the caller guarantees SSE4.1; confirm.
fn sort(source: &mut [i32]) {
    /// Sorts the 8 values held in `value1` and `value2`.
    ///
    /// Stages L1 and L2 are done here with lane-wise min/max plus shuffles;
    /// the last stage (L3) and the choice of output order are delegated to
    /// `sort8l3`. The result is ascending across `(value1, value2)` when
    /// `rev` is false and descending when `rev` is true (assuming shuffle
    /// indices 4..=7 select from the second operand).
    #[inline]
    fn sort8(value1: i32x4, value2: i32x4, rev: bool) -> (i32x4, i32x4) {
        unsafe {
            // L1
            // Compare lane i of value1 with lane i of value2: minima stay in
            // value1, maxima go to value2 with neighbouring lanes swapped so
            // that the next min/max pairs the right elements.
            let max = x86_mm_max_epi32(value1, value2);
            let value1 = x86_mm_min_epi32(value1, value2);
            let value2 = simd_shuffle4(max, value1, [1, 0, 3, 2]);

            // L2
            // Two min/max rounds; each is followed by shuffles that
            // redistribute the lanes for the next comparison.
            let min = x86_mm_min_epi32(value1, value2);
            let max = x86_mm_max_epi32(value1, value2);
            let value1 = simd_shuffle4(min, max, [0, 4, 7, 3]);
            let value2 = simd_shuffle4(min, max, [1, 5, 6, 2]);

            let min = x86_mm_min_epi32(value1, value2);
            let max = x86_mm_max_epi32(value1, value2);
            let value1 = simd_shuffle4(min, max, [0, 4, 1, 5]);
            let value2 = simd_shuffle4(min, max, [6, 2, 7, 3]);

            // L3
            return sort8l3(value1, value2, rev);
        }
    }

    /// Final (L3) stage for 8 values: three min/max rounds with shuffles in
    /// between, then a last shuffle that lays the lanes out in ascending
    /// order (`rev == false`) or descending order (`rev == true`).
    ///
    /// Besides being the tail of `sort8`, this is applied in the main loop to
    /// every pair of adjacent vectors after the vector-level compare-exchange
    /// passes.
    #[inline]
    fn sort8l3(value1: i32x4, value2: i32x4, rev: bool) -> (i32x4, i32x4) {
        unsafe {
            // L3
            // Round 1: lane i of value1 against lane i of value2.
            let min = x86_mm_min_epi32(value1, value2);
            let max = x86_mm_max_epi32(value1, value2);
            let value1 = simd_shuffle4(min, max, [0, 1, 4, 5]);
            let value2 = simd_shuffle4(min, max, [2, 3, 6, 7]);

            // Round 2.
            let min = x86_mm_min_epi32(value1, value2);
            let max = x86_mm_max_epi32(value1, value2);
            let value1 = simd_shuffle4(min, max, [0, 4, 2, 6]);
            let value2 = simd_shuffle4(min, max, [1, 5, 3, 7]);

            // Round 3, followed by the final interleave of minima and maxima.
            let min = x86_mm_min_epi32(value1, value2);
            let max = x86_mm_max_epi32(value1, value2);
            if rev {
                // Same interleave as below, read from the last lane backwards.
                let value1 = simd_shuffle4(min, max, [7, 3, 6, 2]);
                let value2 = simd_shuffle4(min, max, [5, 1, 4, 0]);
                return (value1, value2);
            } else {
                let value1 = simd_shuffle4(min, max, [0, 4, 1, 5]);
                let value2 = simd_shuffle4(min, max, [2, 6, 3, 7]);
                return (value1, value2);
            }
        }
    }

    // Bail out unless the length is a power of two. `next_power_of_two`
    // returns 1 for 0, so an empty slice also returns here.
    let source_size = source.len();
    let size = source_size.next_power_of_two();
    if source_size != size {
        return;
    }
    let half_size = size >> 1;

    if size >= 8 {
        // Number of 4-lane vectors, and number of vector pairs.
        let size_simd = size >> 2;
        let half_size_simd = half_size >> 2;
        let mut work = Vec::<i32x4>::with_capacity(size_simd);
        // Load the slice 8 elements at a time and sort each chunk, reversing
        // every odd-numbered chunk so neighbouring chunks have opposite order.
        for k in 0..half_size_simd {
            let l = k << 3;

            let (value1, value2) =
                sort8(
                    i32x4(source[l], source[l + 1], source[l + 2], source[l + 3]),
                    i32x4(source[l + 4], source[l + 5], source[l + 6], source[l + 7]),
                    k & 1 != 0,
                );
            work.push(value1);
            work.push(value2);
        }

        // `i` is the length, in vectors, of the runs produced by this phase.
        let mut i = 4;
        while i <= size_simd {
            // Compare-exchange whole vectors that are `j` vectors apart.
            // Distance 1 and the in-vector steps are left to `sort8l3` below.
            let mut j = i >> 1;
            while j > 1 {
                let ml = j - 1; // low-bit mask
                let mh = !ml; // high-bit mask

                for k in 0..half_size_simd {
                    // Insert a zero bit at position `j` of the pair number to
                    // get the lower vector index; the partner is `j` above it.
                    let l = ((k & mh) << 1) | (k & ml);
                    let m = l + j;
                    // Runs whose `i` bit is clear are built in ascending order.
                    let isnot_reverse = (l & i) == 0;

                    unsafe {
                        let max = x86_mm_max_epi32(work[l], work[m]);
                        let min = x86_mm_min_epi32(work[l], work[m]);

                        if isnot_reverse {
                            work[l] = min;
                            work[m] = max;
                        } else {
                            work[l] = max;
                            work[m] = min;
                        }
                    }
                }
                j >>= 1;
            }

            // Finish the phase inside each pair of adjacent vectors, using the
            // same direction rule as the vector-level passes.
            for k in 0..half_size_simd {
                let l = k << 1;
                let m = l + 1;
                let is_reverse = (l & i) != 0;

                let (value1, value2) = sort8l3(work[l], work[m], is_reverse);
                work[l] = value1;
                work[m] = value2;
            }
            i <<= 1;
        }

        // Write the sorted vectors back into the slice.
        for j in 0..work.len() {
            source[(j << 2)] = work[j].0;
            source[(j << 2) + 1] = work[j].1;
            source[(j << 2) + 2] = work[j].2;
            source[(j << 2) + 3] = work[j].3;
        }
    } else {
        // Scalar bitonic network for the remaining sizes (fewer than 8 elements).
        let mut i = 2;
        while i <= size {
            let mut j = i >> 1;
            while j > 0 {
                let ml = j - 1; // low-bit mask
                let mh = !ml; // high-bit mask

                for k in 0..half_size {
                    let l = ((k & mh) << 1) | (k & ml);
                    let m = l + j;

                    // Swap when the pair's current order disagrees with the
                    // direction of its run (ascending when `l & i == 0`).
                    if ((l & i) == 0) ^ (source[l] < source[m]) {
                        source.swap(l, m);
                    }
                }
                j >>= 1;
            }
            i <<= 1;
        }
    }
}

ソースコード(i32, AVX2版)

BitonicSort(AVX2)
#![feature(repr_simd, platform_intrinsics)]

/// Vector of eight `i32` lanes.
///
/// Declared `#[repr(simd)]` (enabled by the crate-level `repr_simd` feature)
/// and used as the operand type of the 256-bit min/max intrinsics in this file.
/// The lower-case name needs the `non_camel_case_types` allowance.
#[allow(non_camel_case_types)]
#[repr(simd)]
#[derive(Debug, Copy, Clone)]
struct i32x8(i32, i32, i32, i32, i32, i32, i32, i32);

// Intrinsics declared through the unstable "platform-intrinsic" ABI
// (enabled by the crate-level `platform_intrinsics` feature).
extern "platform-intrinsic" {
    /// Builds an 8-lane vector whose lanes are selected by `idx` from `x` and `y`.
    /// NOTE(review): the callers in this file rely on indices 0..=7 addressing
    /// `x` and 8..=15 addressing `y` — confirm against the intrinsic's definition.
    fn simd_shuffle8<T, U>(x: T, y: T, idx: [u32; 8]) -> U;
    /// Lane-wise minimum of two `i32x8` values
    /// (by its name, the AVX2 `_mm256_min_epi32` operation — TODO confirm).
    fn x86_mm256_min_epi32(x: i32x8, y: i32x8) -> i32x8;
    /// Lane-wise maximum of two `i32x8` values
    /// (by its name, the AVX2 `_mm256_max_epi32` operation — TODO confirm).
    fn x86_mm256_max_epi32(x: i32x8, y: i32x8) -> i32x8;
}

/// Sorts `source` in place in ascending order with a bitonic sorting network.
///
/// Only power-of-two lengths are handled; for any other length (including an
/// empty slice) the function returns without modifying `source`. Slices of 16
/// or more elements are processed as `i32x8` vectors with min/max/shuffle
/// intrinsics; shorter slices use a scalar compare-and-swap network.
///
/// The trailing number lists on the intrinsic calls appear to give, for each
/// lane, the 1-based slot of the 16-element network that the lane's value
/// belongs to at that point — NOTE(review): confirm this reading.
///
/// NOTE(review): nothing in this function checks that the CPU supports the
/// intrinsics it calls — presumably the caller guarantees AVX2; confirm.
fn sort(source: &mut [i32]) {
    /// Final (L4) stage for 16 values: four min/max rounds with shuffles in
    /// between, then a last shuffle that lays the lanes out in ascending
    /// order (`rev == false`) or descending order (`rev == true`).
    ///
    /// Besides being the tail of `sort16`, this is applied in the main loop to
    /// every pair of adjacent vectors after the vector-level compare-exchange
    /// passes.
    #[inline]
    fn sort16l4(value1: i32x8, value2: i32x8, rev: bool) -> (i32x8, i32x8) {
        unsafe {
            // L4
            let min = x86_mm256_min_epi32(value1, value2); // 1,2,3,4,5,6,7,8
            let max = x86_mm256_max_epi32(value1, value2); // 9,10,11,12,13,14,15,16
            let value1 = simd_shuffle8(min, max, [0, 1, 2, 3, 8, 9, 10, 11]); // 1,2,3,4,9,10,11,12
            let value2 = simd_shuffle8(min, max, [4, 5, 6, 7, 12, 13, 14, 15]); // 5,6,7,8,13,14,15,16

            let min = x86_mm256_min_epi32(value1, value2); // 1,2,3,4,9,10,11,12
            let max = x86_mm256_max_epi32(value1, value2); // 5,6,7,8,13,14,15,16
            let value1 = simd_shuffle8(min, max, [0, 1, 8, 9, 4, 5, 12, 13]); // 1,2,5,6,9,10,13,14
            let value2 = simd_shuffle8(min, max, [2, 3, 10, 11, 6, 7, 14, 15]); // 3,4,7,8,11,12,15,16

            let min = x86_mm256_min_epi32(value1, value2); // 1,2,5,6,9,10,13,14
            let max = x86_mm256_max_epi32(value1, value2); // 3,4,7,8,11,12,15,16
            let value1 = simd_shuffle8(min, max, [0, 8, 2, 10, 4, 12, 6, 14]); // 1,3,5,7,9,11,13,15
            let value2 = simd_shuffle8(min, max, [1, 9, 3, 11, 5, 13, 7, 15]); // 2,4,6,8,10,12,14,16

            let min = x86_mm256_min_epi32(value1, value2); // 1,3,5,7,9,11,13,15
            let max = x86_mm256_max_epi32(value1, value2); // 2,4,6,8,10,12,14,16
            if rev {
                // Same interleave as below, read from the last lane backwards.
                let value1 = simd_shuffle8(min, max, [15, 7, 14, 6, 13, 5, 12, 4]);
                let value2 = simd_shuffle8(min, max, [11, 3, 10, 2, 9, 1, 8, 0]);
                return (value1, value2);
            } else {
                let value1 = simd_shuffle8(min, max, [0, 8, 1, 9, 2, 10, 3, 11]);
                let value2 = simd_shuffle8(min, max, [4, 12, 5, 13, 6, 14, 7, 15]);
                return (value1, value2);
            }
        }
    }
    /// Sorts the 16 values held in `value1` and `value2`.
    ///
    /// Stages L1 to L3 are done here with lane-wise min/max plus shuffles;
    /// the last stage (L4) and the choice of output order (`rev`) are
    /// delegated to `sort16l4`.
    #[inline]
    fn sort16(value1: i32x8, value2: i32x8, rev: bool) -> (i32x8, i32x8) {
        unsafe {
            // L1
            // Compare lane i of value1 with lane i of value2: minima stay in
            // value1, maxima go to value2 with neighbouring lanes swapped.
            let max = x86_mm256_max_epi32(value1, value2); // 2,3,6,7,10,11,14,15
            let value1 = x86_mm256_min_epi32(value1, value2); // 1,4,5,8,9,12,13,16
            let value2 = simd_shuffle8(max, value1, [1, 0, 3, 2, 5, 4, 7, 6]); // 3,2,7,6,11,10,15,14

            // L2
            let min = x86_mm256_min_epi32(value1, value2); // 1,2,7,8,9,10,15,16
            let max = x86_mm256_max_epi32(value1, value2); // 3,4,5,6,11,12,13,14
            let value1 = simd_shuffle8(min, max, [0, 8, 11, 3, 4, 12, 15, 7]); // 1,3,6,8,9,11,14,16
            let value2 = simd_shuffle8(min, max, [1, 9, 10, 2, 5, 13, 14, 6]); // 2,4,5,7,10,12,13,15

            let min = x86_mm256_min_epi32(value1, value2); // 1,3,6,8,9,11,14,16
            let max = x86_mm256_max_epi32(value1, value2); // 2,4,5,7,10,12,13,15
            let value1 = simd_shuffle8(min, max, [0, 8, 1, 9, 4, 12, 5, 13]); // 1,2,3,4,9,10,11,12
            let value2 = simd_shuffle8(min, max, [10, 2, 11, 3, 14, 6, 15, 7]); // 5,6,7,8,13,14,15,16

            // L3
            let min = x86_mm256_min_epi32(value1, value2); // 1,2,3,4,13,14,15,16
            let max = x86_mm256_max_epi32(value1, value2); // 5,6,7,8,9,10,11,12
            let value1 = simd_shuffle8(min, max, [0, 1, 8, 9, 4, 5, 12, 13]); // 1,2,5,6,13,14,9,10
            let value2 = simd_shuffle8(min, max, [2, 3, 10, 11, 6, 7, 14, 15]); // 3,4,7,8,15,16,11,12

            let min = x86_mm256_min_epi32(value1, value2); // 1,2,5,6,15,16,11,12
            let max = x86_mm256_max_epi32(value1, value2); // 3,4,7,8,13,14,9,10
            let value1 = simd_shuffle8(min, max, [0, 8, 2, 10, 14, 6, 12, 4]); // 1,3,5,7,9,11,13,15
            let value2 = simd_shuffle8(min, max, [1, 9, 3, 11, 15, 7, 13, 5]); // 2,4,6,8,10,12,14,16

            let min = x86_mm256_min_epi32(value1, value2); // 1,3,5,7,10,12,14,16
            let max = x86_mm256_max_epi32(value1, value2); // 2,4,6,8,9,11,13,15
            let value1 = simd_shuffle8(min, max, [0, 8, 1, 9, 2, 10, 3, 11]); // 1,2,3,4,5,6,7,8
            let value2 = simd_shuffle8(min, max, [12, 4, 13, 5, 14, 6, 15, 7]); // 9,10,11,12,13,14,15,16

            // L4
            return sort16l4(value1, value2, rev);
        }
    }

    // Bail out unless the length is a power of two. `next_power_of_two`
    // returns 1 for 0, so an empty slice also returns here.
    let source_size = source.len();
    let size = source_size.next_power_of_two();
    if source_size != size {
        return;
    }
    let half_size = size >> 1;

    if size >= 16 {
        // Number of 8-lane vectors, and number of vector pairs.
        let size_simd = size >> 3;
        let half_size_simd = half_size >> 3;
        let mut work = Vec::<i32x8>::with_capacity(size_simd);
        // Load the slice 16 elements at a time and sort each chunk, reversing
        // every odd-numbered chunk so neighbouring chunks have opposite order.
        for k in 0..half_size_simd {
            let l = k << 4;

            let (value1, value2) = sort16(
                i32x8(
                    source[l],
                    source[l + 1],
                    source[l + 2],
                    source[l + 3],
                    source[l + 4],
                    source[l + 5],
                    source[l + 6],
                    source[l + 7],
                ),
                i32x8(
                    source[l + 8],
                    source[l + 9],
                    source[l + 10],
                    source[l + 11],
                    source[l + 12],
                    source[l + 13],
                    source[l + 14],
                    source[l + 15],
                ),
                k & 1 != 0,
            );
            work.push(value1);
            work.push(value2);
        }

        // `i` is the length, in vectors, of the runs produced by this phase.
        let mut i = 4;
        while i <= size_simd {
            // Compare-exchange whole vectors that are `j` vectors apart.
            // Distance 1 and the in-vector steps are left to `sort16l4` below.
            let mut j = i >> 1;
            while j > 1 {
                let ml = j - 1; // low-bit mask
                let mh = !ml; // high-bit mask

                for k in 0..half_size_simd {
                    // Insert a zero bit at position `j` of the pair number to
                    // get the lower vector index; the partner is `j` above it.
                    let l = ((k & mh) << 1) | (k & ml);
                    let m = l + j;
                    // Runs whose `i` bit is clear are built in ascending order.
                    let isnot_reverse = (l & i) == 0;

                    unsafe {
                        let max = x86_mm256_max_epi32(work[l], work[m]);
                        let min = x86_mm256_min_epi32(work[l], work[m]);

                        if isnot_reverse {
                            work[l] = min;
                            work[m] = max;
                        } else {
                            work[l] = max;
                            work[m] = min;
                        }
                    }
                }
                j >>= 1;
            }

            // Finish the phase inside each pair of adjacent vectors, using the
            // same direction rule as the vector-level passes.
            for k in 0..half_size_simd {
                let l = k << 1;
                let m = l + 1;
                let is_reverse = (l & i) != 0;

                let (value1, value2) = sort16l4(work[l], work[m], is_reverse);
                work[l] = value1;
                work[m] = value2;
            }
            i <<= 1;
        }

        // Write the sorted vectors back into the slice.
        for j in 0..work.len() {
            source[(j << 3)] = work[j].0;
            source[(j << 3) + 1] = work[j].1;
            source[(j << 3) + 2] = work[j].2;
            source[(j << 3) + 3] = work[j].3;
            source[(j << 3) + 4] = work[j].4;
            source[(j << 3) + 5] = work[j].5;
            source[(j << 3) + 6] = work[j].6;
            source[(j << 3) + 7] = work[j].7;
        }
    } else {
        // Scalar bitonic network for the remaining sizes (fewer than 16 elements).
        let mut i = 2;
        while i <= size {
            let mut j = i >> 1;
            while j > 0 {
                let ml = j - 1; // low-bit mask
                let mh = !ml; // high-bit mask

                for k in 0..half_size {
                    let l = ((k & mh) << 1) | (k & ml);
                    let m = l + j;

                    // Swap when the pair's current order disagrees with the
                    // direction of its run (ascending when `l & i == 0`).
                    if ((l & i) == 0) ^ (source[l] < source[m]) {
                        source.swap(l, m);
                    }
                }
                j >>= 1;
            }
            i <<= 1;
        }
    }
}

かなり高速化した。要素数次第ではあるが、標準ソートより速い。

ベンチマーク

Rustでソートアルゴリズム (0)まとめ