概要
並列化が容易なため、GPGPU等で有用であるバイトニックソートを実装する。
Wikipediaでの説明(英語)
ナイーブな実装
前提条件
前回と同様に、Rustのバージョンは1.19.0とする。
% rustc --version
rustc 1.19.0 (0ade33941 2017-07-17)
ソースコード
BitonicSort
/// Sorts `source` in place, in ascending order, with a bitonic sorting network.
///
/// Only power-of-two lengths are handled: when `source.len()` is not a power
/// of two (the empty slice included) the function returns immediately and
/// leaves `source` untouched.
fn sort<T: PartialOrd>(&self, source: &mut [T]) {
    let len = source.len();
    if len.next_power_of_two() != len {
        return;
    }
    // Every pass performs exactly one compare-exchange per pair of elements.
    let pair_count = len >> 1;

    // `run` is the length of the runs produced by the current phase; it doubles
    // until a single run covers the whole slice.
    let mut run = 2;
    while run <= len {
        // `stride` is the distance between the two elements being compared.
        let mut stride = run >> 1;
        while stride != 0 {
            let low_mask = stride - 1;
            for pair in 0..pair_count {
                // Insert a zero bit at the `stride` position of `pair` to get the
                // lower index of the pair; the partner sits `stride` further on.
                let lo = ((pair & !low_mask) << 1) | (pair & low_mask);
                let hi = lo + stride;
                // The `run` bit of the index decides the direction of this run.
                let ascending = (lo & run) == 0;
                let in_order = source[lo] < source[hi];
                if ascending != in_order {
                    source.swap(lo, hi);
                }
            }
            stride >>= 1;
        }
        run <<= 1;
    }
}
計算量が $O(n \log^2 n)$ であるため、あまり速くない。
ただし、このソートは並列化が容易なので、SIMDで高速化してみる。
SIMDを使う
事前作業
nightlyに変更
% rustup run nightly rustc --version
rustc 1.21.0-nightly (dd53dd5f9 2017-08-01)
ポイント
SIMDを使用するにあたって、simd crateを使うことが一般的だが、今回はplatform-intrinsicを使うこととした。
参考資料
http://qiita.com/tatsuya6502/items/7ffc623fc60be0220409
http://mayah.jp/article/2016/x86intrin/
ソースコード(i32, SSE4.1版)
BitonicSort(SSE4.1)
# ![feature(repr_simd, platform_intrinsics)]
/// Vector of four `i32` lanes, declared `repr(simd)`.
///
/// The lower-case name follows the usual SIMD type naming, which is why the
/// `non_camel_case_types` lint is allowed here.
# [allow(non_camel_case_types)]
# [repr(simd)]
# [derive(Debug, Copy, Clone)]
struct i32x4(i32, i32, i32, i32);
// Compiler-provided intrinsics declared through the nightly-only
// "platform-intrinsic" ABI. Being foreign declarations, they can only be
// called from `unsafe` code.
extern "platform-intrinsic" {
/// Builds a four-lane vector out of lanes picked from `x` and `y`: each entry
/// of `idx` names one source lane, 0-3 selecting from `x` and 4-7 from `y`.
fn simd_shuffle4<T, U>(x: T, y: T, idx: [u32; 4]) -> U;
/// Lane-wise signed minimum of `x` and `y` (the SSE4.1 `_mm_min_epi32` operation).
fn x86_mm_min_epi32(x: i32x4, y: i32x4) -> i32x4;
/// Lane-wise signed maximum of `x` and `y` (the SSE4.1 `_mm_max_epi32` operation).
fn x86_mm_max_epi32(x: i32x4, y: i32x4) -> i32x4;
}
/// Sorts `source` in place, in ascending order, using a bitonic sorting
/// network whose compare-exchange steps work on four `i32` lanes at a time.
///
/// Only power-of-two lengths are handled: for any other length (the empty
/// slice included) the function returns without modifying `source`. Slices
/// shorter than eight elements are sorted by a scalar bitonic network.
fn sort(source: &mut [i32]) {
    /// Runs the last three compare-exchange steps (element distance 4, 2, 1)
    /// on the eight values held in `front` and `back`.
    ///
    /// The result is returned in ascending order across `(first, second)`, or
    /// in descending order when `descending` is true. The numbers in the
    /// comments are the 1-based network positions held by each lane.
    #[inline]
    fn merge_tile8(front: i32x4, back: i32x4, descending: bool) -> (i32x4, i32x4) {
        // The intrinsics are foreign declarations, hence the `unsafe` block;
        // they only receive vectors by value.
        // NOTE(review): presumably needs an SSE4.1-capable target - confirm build flags.
        unsafe {
            // Distance 4.
            let lo = x86_mm_min_epi32(front, back); // 1,2,3,4
            let hi = x86_mm_max_epi32(front, back); // 5,6,7,8
            let a: i32x4 = simd_shuffle4(lo, hi, [0, 1, 4, 5]); // 1,2,5,6
            let b: i32x4 = simd_shuffle4(lo, hi, [2, 3, 6, 7]); // 3,4,7,8
            // Distance 2.
            let lo = x86_mm_min_epi32(a, b); // 1,2,5,6
            let hi = x86_mm_max_epi32(a, b); // 3,4,7,8
            let a: i32x4 = simd_shuffle4(lo, hi, [0, 4, 2, 6]); // 1,3,5,7
            let b: i32x4 = simd_shuffle4(lo, hi, [1, 5, 3, 7]); // 2,4,6,8
            // Distance 1.
            let lo = x86_mm_min_epi32(a, b); // 1,3,5,7
            let hi = x86_mm_max_epi32(a, b); // 2,4,6,8
            // Interleave back into memory order, reversed if requested.
            if descending {
                let first: i32x4 = simd_shuffle4(lo, hi, [7, 3, 6, 2]); // 8,7,6,5
                let second: i32x4 = simd_shuffle4(lo, hi, [5, 1, 4, 0]); // 4,3,2,1
                (first, second)
            } else {
                let first: i32x4 = simd_shuffle4(lo, hi, [0, 4, 1, 5]); // 1,2,3,4
                let second: i32x4 = simd_shuffle4(lo, hi, [2, 6, 3, 7]); // 5,6,7,8
                (first, second)
            }
        }
    }

    /// Fully sorts the eight values held in `front` and `back`.
    ///
    /// The first two stages build an ascending run in positions 1-4 and a
    /// descending run in positions 5-8; `merge_tile8` then merges them in the
    /// direction selected by `descending`.
    #[inline]
    fn sort_tile8(front: i32x4, back: i32x4, descending: bool) -> (i32x4, i32x4) {
        // See `merge_tile8` for why this block is `unsafe`.
        unsafe {
            // Stage 1: pairs (1,2) ascending, (3,4) descending, and so on.
            let hi = x86_mm_max_epi32(front, back); // 2,3,6,7
            let a = x86_mm_min_epi32(front, back); // 1,4,5,8
            let b: i32x4 = simd_shuffle4(hi, a, [1, 0, 3, 2]); // 3,2,7,6
            // Stage 2, distance 2.
            let lo = x86_mm_min_epi32(a, b); // 1,2,7,8
            let hi = x86_mm_max_epi32(a, b); // 3,4,5,6
            let a: i32x4 = simd_shuffle4(lo, hi, [0, 4, 7, 3]); // 1,3,6,8
            let b: i32x4 = simd_shuffle4(lo, hi, [1, 5, 6, 2]); // 2,4,5,7
            // Stage 2, distance 1.
            let lo = x86_mm_min_epi32(a, b); // 1,3,6,8
            let hi = x86_mm_max_epi32(a, b); // 2,4,5,7
            let a: i32x4 = simd_shuffle4(lo, hi, [0, 4, 1, 5]); // 1,2,3,4
            let b: i32x4 = simd_shuffle4(lo, hi, [6, 2, 7, 3]); // 5,6,7,8
            // Stage 3.
            merge_tile8(a, b, descending)
        }
    }

    /// Scalar bitonic network for slices too short to fill two vectors.
    fn sort_scalar(values: &mut [i32]) {
        let len = values.len();
        let pair_count = len >> 1;
        let mut run = 2;
        while run <= len {
            let mut stride = run >> 1;
            while stride != 0 {
                let low_mask = stride - 1;
                for pair in 0..pair_count {
                    // Insert a zero bit at the `stride` position of `pair`.
                    let lo = ((pair & !low_mask) << 1) | (pair & low_mask);
                    let hi = lo + stride;
                    let ascending = (lo & run) == 0;
                    let in_order = values[lo] < values[hi];
                    if ascending != in_order {
                        values.swap(lo, hi);
                    }
                }
                stride >>= 1;
            }
            run <<= 1;
        }
    }

    let len = source.len();
    if len.next_power_of_two() != len {
        return;
    }
    if len < 8 {
        sort_scalar(source);
        return;
    }

    let vector_count = len >> 2;
    let pair_count = vector_count >> 1;

    // Load the data eight elements at a time, sorting every tile on the way
    // in; odd tiles are sorted in descending order so that neighbouring tiles
    // form bitonic sequences.
    let mut work = Vec::<i32x4>::with_capacity(vector_count);
    for (tile_index, tile) in source.chunks(8).enumerate() {
        let (first, second) = sort_tile8(
            i32x4(tile[0], tile[1], tile[2], tile[3]),
            i32x4(tile[4], tile[5], tile[6], tile[7]),
            (tile_index & 1) == 1,
        );
        work.push(first);
        work.push(second);
    }

    // `run` and `stride` are measured in vectors from here on.
    let mut run = 4;
    while run <= vector_count {
        // Compare-exchange whole vectors while the partners are at least two
        // vectors apart.
        let mut stride = run >> 1;
        while stride > 1 {
            let low_mask = stride - 1;
            for pair in 0..pair_count {
                let lo = ((pair & !low_mask) << 1) | (pair & low_mask);
                let hi = lo + stride;
                let (small, large) = unsafe {
                    (
                        x86_mm_min_epi32(work[lo], work[hi]),
                        x86_mm_max_epi32(work[lo], work[hi]),
                    )
                };
                if (lo & run) == 0 {
                    work[lo] = small;
                    work[hi] = large;
                } else {
                    work[lo] = large;
                    work[hi] = small;
                }
            }
            stride >>= 1;
        }
        // The remaining steps stay inside one pair of adjacent vectors.
        for (pair_index, pair) in work.chunks_mut(2).enumerate() {
            let descending = ((pair_index << 1) & run) != 0;
            let (first, second) = merge_tile8(pair[0], pair[1], descending);
            pair[0] = first;
            pair[1] = second;
        }
        run <<= 1;
    }

    // Write the sorted vectors back to the slice.
    for (lanes, vector) in source.chunks_mut(4).zip(work.iter()) {
        lanes[0] = vector.0;
        lanes[1] = vector.1;
        lanes[2] = vector.2;
        lanes[3] = vector.3;
    }
}
ソースコード(i32, AVX2版)
BitonicSort(AVX2)
# ![feature(repr_simd, platform_intrinsics)]
/// Vector of eight `i32` lanes, declared `repr(simd)`.
///
/// The lower-case name follows the usual SIMD type naming, which is why the
/// `non_camel_case_types` lint is allowed here.
# [allow(non_camel_case_types)]
# [repr(simd)]
# [derive(Debug, Copy, Clone)]
struct i32x8(i32, i32, i32, i32, i32, i32, i32, i32);
// Compiler-provided intrinsics declared through the nightly-only
// "platform-intrinsic" ABI. Being foreign declarations, they can only be
// called from `unsafe` code.
extern "platform-intrinsic" {
/// Builds an eight-lane vector out of lanes picked from `x` and `y`: each entry
/// of `idx` names one source lane, 0-7 selecting from `x` and 8-15 from `y`.
fn simd_shuffle8<T, U>(x: T, y: T, idx: [u32; 8]) -> U;
/// Lane-wise signed minimum of `x` and `y` (the AVX2 `_mm256_min_epi32` operation).
fn x86_mm256_min_epi32(x: i32x8, y: i32x8) -> i32x8;
/// Lane-wise signed maximum of `x` and `y` (the AVX2 `_mm256_max_epi32` operation).
fn x86_mm256_max_epi32(x: i32x8, y: i32x8) -> i32x8;
}
fn sort(source: &mut [i32]) {
#[inline]
fn sort16l4(value1: i32x8, value2: i32x8, rev: bool) -> (i32x8, i32x8) {
unsafe {
// L4
let min = x86_mm256_min_epi32(value1, value2); // 1,2,3,4,5,6,7,8
let max = x86_mm256_max_epi32(value1, value2); // 9,10,11,12,13,14,15,16
let value1 = simd_shuffle8(min, max, [0, 1, 2, 3, 8, 9, 10, 11]); // 1,2,3,4,9,10,11,12
let value2 = simd_shuffle8(min, max, [4, 5, 6, 7, 12, 13, 14, 15]); // 5,6,7,8,13,14,15,16
let min = x86_mm256_min_epi32(value1, value2); // 1,2,3,4,9,10,11,12
let max = x86_mm256_max_epi32(value1, value2); // 5,6,7,8,13,14,15,16
let value1 = simd_shuffle8(min, max, [0, 1, 8, 9, 4, 5, 12, 13]); // 1,2,5,6,9,10,13,14
let value2 = simd_shuffle8(min, max, [2, 3, 10, 11, 6, 7, 14, 15]); // 3,4,7,8,11,12,15,16
let min = x86_mm256_min_epi32(value1, value2); // 1,2,5,6,9,10,13,14
let max = x86_mm256_max_epi32(value1, value2); // 3,4,7,8,11,12,15,16
let value1 = simd_shuffle8(min, max, [0, 8, 2, 10, 4, 12, 6, 14]); // 1,3,5,7,9,11,13,15
let value2 = simd_shuffle8(min, max, [1, 9, 3, 11, 5, 13, 7, 15]); // 2,4,6,8,10,12,14,16
let min = x86_mm256_min_epi32(value1, value2); // 1,3,5,7,9,11,13,15
let max = x86_mm256_max_epi32(value1, value2); // 2,4,6,8,10,12,14,16
if rev {
let value1 = simd_shuffle8(min, max, [15, 7, 14, 6, 13, 5, 12, 4]);
let value2 = simd_shuffle8(min, max, [11, 3, 10, 2, 9, 1, 8, 0]);
return (value1, value2);
} else {
let value1 = simd_shuffle8(min, max, [0, 8, 1, 9, 2, 10, 3, 11]);
let value2 = simd_shuffle8(min, max, [4, 12, 5, 13, 6, 14, 7, 15]);
return (value1, value2);
}
}
}
#[inline]
fn sort16(value1: i32x8, value2: i32x8, rev: bool) -> (i32x8, i32x8) {
unsafe {
// L1
let max = x86_mm256_max_epi32(value1, value2); // 2,3,6,7,10,11,14,15
let value1 = x86_mm256_min_epi32(value1, value2); // 1,4,5,8,9,12,13,16
let value2 = simd_shuffle8(max, value1, [1, 0, 3, 2, 5, 4, 7, 6]); // 3,2,7,6,11,10,15,14
// L2
let min = x86_mm256_min_epi32(value1, value2); // 1,2,7,8,9,10,15,16
let max = x86_mm256_max_epi32(value1, value2); // 3,4,5,6,11,12,13,14
let value1 = simd_shuffle8(min, max, [0, 8, 11, 3, 4, 12, 15, 7]); // 1,3,6,8,9,11,14,16
let value2 = simd_shuffle8(min, max, [1, 9, 10, 2, 5, 13, 14, 6]); // 2,4,5,7,10,12,13,15
let min = x86_mm256_min_epi32(value1, value2); // 1,3,6,8,9,11,14,16
let max = x86_mm256_max_epi32(value1, value2); // 2,4,5,7,10,12,13,15
let value1 = simd_shuffle8(min, max, [0, 8, 1, 9, 4, 12, 5, 13]); // 1,2,3,4,9,10,11,12
let value2 = simd_shuffle8(min, max, [10, 2, 11, 3, 14, 6, 15, 7]); // 5,6,7,8,13,14,15,16
// L3
let min = x86_mm256_min_epi32(value1, value2); // 1,2,3,4,13,14,15,16
let max = x86_mm256_max_epi32(value1, value2); // 5,6,7,8,9,10,11,12
let value1 = simd_shuffle8(min, max, [0, 1, 8, 9, 4, 5, 12, 13]); // 1,2,5,6,13,14,9,10
let value2 = simd_shuffle8(min, max, [2, 3, 10, 11, 6, 7, 14, 15]); // 3,4,7,8,15,16,11,12
let min = x86_mm256_min_epi32(value1, value2); // 1,2,5,6,15,16,11,12
let max = x86_mm256_max_epi32(value1, value2); // 3,4,7,8,13,14,9,10
let value1 = simd_shuffle8(min, max, [0, 8, 2, 10, 14, 6, 12, 4]); // 1,3,5,7,9,11,13,15
let value2 = simd_shuffle8(min, max, [1, 9, 3, 11, 15, 7, 13, 5]); // 2,4,6,8,10,12,14,16
let min = x86_mm256_min_epi32(value1, value2); // 1,3,5,7,10,12,14,16
let max = x86_mm256_max_epi32(value1, value2); // 2,4,6,8,9,11,13,15
let value1 = simd_shuffle8(min, max, [0, 8, 1, 9, 2, 10, 3, 11]); // 1,2,3,4,5,6,7,8
let value2 = simd_shuffle8(min, max, [12, 4, 13, 5, 14, 6, 15, 7]); // 9,10,11,12,13,14,15
// L4
return sort16l4(value1, value2, rev);
}
}
let source_size = source.len();
let size = source_size.next_power_of_two();
if source_size != size {
return;
}
let half_size = size >> 1;
if size >= 16 {
let size_simd = size >> 3;
let half_size_simd = half_size >> 3;
let mut work = Vec::<i32x8>::with_capacity(size_simd);
for k in 0..half_size_simd {
let l = k << 4;
let (value1, value2) = sort16(
i32x8(
source[l],
source[l + 1],
source[l + 2],
source[l + 3],
source[l + 4],
source[l + 5],
source[l + 6],
source[l + 7],
),
i32x8(
source[l + 8],
source[l + 9],
source[l + 10],
source[l + 11],
source[l + 12],
source[l + 13],
source[l + 14],
source[l + 15],
),
k & 1 != 0,
);
work.push(value1);
work.push(value2);
}
let mut i = 4;
while i <= size_simd {
let mut j = i >> 1;
while j > 1 {
let ml = j - 1; // 下位ビットマスク
let mh = !ml; // 上位ビットマスク
for k in 0..half_size_simd {
let l = ((k & mh) << 1) | (k & ml);
let m = l + j;
let isnot_reverse = (l & i) == 0;
unsafe {
let max = x86_mm256_max_epi32(work[l], work[m]);
let min = x86_mm256_min_epi32(work[l], work[m]);
if isnot_reverse {
work[l] = min;
work[m] = max;
} else {
work[l] = max;
work[m] = min;
}
}
}
j >>= 1;
}
for k in 0..half_size_simd {
let l = k << 1;
let m = l + 1;
let is_reverse = (l & i) != 0;
let (value1, value2) = sort16l4(work[l], work[m], is_reverse);
work[l] = value1;
work[m] = value2;
}
i <<= 1;
}
// 書き戻し
for j in 0..work.len() {
source[(j << 3)] = work[j].0;
source[(j << 3) + 1] = work[j].1;
source[(j << 3) + 2] = work[j].2;
source[(j << 3) + 3] = work[j].3;
source[(j << 3) + 4] = work[j].4;
source[(j << 3) + 5] = work[j].5;
source[(j << 3) + 6] = work[j].6;
source[(j << 3) + 7] = work[j].7;
}
} else {
let mut i = 2;
while i <= size {
let mut j = i >> 1;
while j > 0 {
let ml = j - 1; // 下位ビットマスク
let mh = !ml; // 上位ビットマスク
for k in 0..half_size {
let l = ((k & mh) << 1) | (k & ml);
let m = l + j;
if ((l & i) == 0) ^ (source[l] < source[m]) {
source.swap(l, m);
}
}
j >>= 1;
}
i <<= 1;
}
}
}
かなり高速化した。要素数次第ではあるが、標準ソートより速い。