アルゴリズム
rust
ソート

Rustでソートアルゴリズム (5)クイックソート・イントロソート

概要

比較を行うソートにおいて一番速いことが多いクイックソートを実装し、最悪パターン対応を行ったイントロソートを実装する。

前提条件

これまでと同様にRustのバージョンは1.19.0とします。

% rustc --version
rustc 1.19.0 (0ade33941 2017-07-17)

クイックソート

Wikipediaでの説明

クイックソート
fn sort<T: PartialOrd + Clone>(source: &mut [T]) {
    fn q_sort<TInner: PartialOrd + Clone>(source: &mut [TInner], left: usize, right: usize) {
        let pivot = source[(left + right) >> 1].clone();
        let mut l = left;
        let mut r = right;
        while l <= r {
            while pivot < source[r] && r > left {
                r -= 1;
            }
            while source[l] < pivot && l < right {
                l += 1;
            }
            if l <= r {
                source.swap(l, r);
                if r > 0 {
                    r -= 1;
                }
                l += 1;
            }
        }
        if left < r {
            q_sort(source, left, r);
        }
        if right > l {
            q_sort(source, l, right);
        }
    }

    let size = source.len() - 1;
    q_sort(source, 0, size);
}

ちょっとイマイチの速度のため、前回と同様に挿入ソートと組み合わせてみる。

クイックソート+挿入ソート
fn sort<T: PartialOrd + Clone>(source: &mut [T]) {
    fn q_sort<TInner: PartialOrd + Clone>(source: &mut [TInner], left: usize, right: usize) {
        const INSERT_THRESHOLD: usize = 32;
        let pivot = source[(left + right) >> 1].clone();
        let mut l = left;
        let mut r = right;
        while l <= r {
            while pivot < source[r] && r > left {
                r -= 1;
            }
            while source[l] < pivot && l < right {
                l += 1;
            }
            if l <= r {
                source.swap(l, r);
                if r > 0 {
                    r -= 1;
                }
                l += 1;
            }
        }
        if r > INSERT_THRESHOLD + left {
            q_sort(source, left, r);
        }
        if right > INSERT_THRESHOLD + l {
            q_sort(source, l, right);
        }
    }

    fn insert_sort<TInner: PartialOrd>(source: &mut [TInner]) {
        let ptr = source.as_mut_ptr();
        for i in 1..source.len() {
            let mut j = i;
            unsafe {
                let mut t: TInner = mem::uninitialized();
                ptr::copy_nonoverlapping(ptr.offset(j as isize), &mut t, 1);
                while 0 < j && *(ptr.offset((j - 1) as isize)) > t {
                    ptr::copy_nonoverlapping(
                        ptr.offset((j - 1) as isize),
                        ptr.offset(j as isize),
                        1,
                    );
                    j -= 1;
                }
                ptr::copy_nonoverlapping(&t, ptr.offset(j as isize), 1);
                mem::forget(t);
            }
        }
    }

    let size = source.len() - 1;
    q_sort(source, 0, size);
    insert_sort(source);
}

だいぶ改善した。
高速化のため、unsafeにしてみる。

クイックソート+挿入ソート(unsafe)
fn sort<T: PartialOrd>(source: &mut [T]) {
    fn q_sort<TInner: PartialOrd>(ptr: *mut TInner, left: isize, right: isize) {
        const INSERT_THRESHOLD: isize = 32;
        let mut l = left;
        let mut r = right;
        unsafe {
            let mut pivot: TInner = mem::uninitialized();
            let mut tmp: TInner = mem::uninitialized();
            ptr::copy_nonoverlapping(ptr.offset(((left + right) >> 1)), &mut pivot, 1);
            while l <= r {
                while r > left && pivot < *(ptr.offset(r)) {
                    r -= 1;
                }
                while l < right && *(ptr.offset(l)) < pivot {
                    l += 1;
                }
                if l <= r {
                    ptr::copy_nonoverlapping(ptr.offset(l), &mut tmp, 1);
                    ptr::copy_nonoverlapping(ptr.offset(r), ptr.offset(l), 1);
                    ptr::copy_nonoverlapping(&mut tmp, ptr.offset(r), 1);
                    if r > 0 {
                        r -= 1;
                    }
                    l += 1;
                }
            }
            if r > INSERT_THRESHOLD + left {
                q_sort(ptr, left, r);
            }
            if right > INSERT_THRESHOLD + l {
                q_sort(ptr, l, right);
            }
            mem::forget(pivot);
            mem::forget(tmp);
        }
    }

    fn insert_sort<TInner: PartialOrd>(source: &mut [TInner]) {
        let ptr = source.as_mut_ptr();
        for i in 1..source.len() {
            let mut j = i;
            unsafe {
                let mut t: TInner = mem::uninitialized();
                ptr::copy_nonoverlapping(ptr.offset(j as isize), &mut t, 1);
                while 0 < j && *(ptr.offset((j - 1) as isize)) > t {
                    ptr::copy_nonoverlapping(
                        ptr.offset((j - 1) as isize),
                        ptr.offset(j as isize),
                        1,
                    );
                    j -= 1;
                }
                ptr::copy_nonoverlapping(&t, ptr.offset(j as isize), 1);
                mem::forget(t);
            }
        }
    }

    let size = source.len() - 1;
    q_sort(source.as_mut_ptr(), 0, size as isize);
    insert_sort(source);
}

QuickSortは、$O(n^2)$の時間がかかるようなデータを投入することにより攻撃ができる。
これは、pivotの選び方によりある程度回避できる。
今回は3データの中央値を取る方法に変更してみる。

クイックソート(3値中央選択)+挿入ソート(unsafe)
fn sort<T: PartialOrd>(source: &mut [T]) {
    fn q_sort<TInner: PartialOrd>(ptr: *mut TInner, left: isize, right: isize) {
        fn get_pivot<TPivot: PartialOrd>(
            ptr: *mut TPivot,
            left: isize,
            right: isize,
        ) -> TPivot {
            unsafe {
                let mut m: TPivot = mem::uninitialized();
                let mut l: TPivot = mem::uninitialized();
                let mut r: TPivot = mem::uninitialized();
                ptr::copy_nonoverlapping(ptr.offset((left + right) >> 1), &mut m, 1);
                ptr::copy_nonoverlapping(ptr.offset(left), &mut l, 1);
                ptr::copy_nonoverlapping(ptr.offset(right), &mut r, 1);
                if m < l {
                    if l < r {
                        l
                    } else if m < r {
                        r
                    } else {
                        m
                    }
                } else {
                    if m < r {
                        m
                    } else if r < l {
                        l
                    } else {
                        r
                    }
                }
            }
        }

        const INSERT_THRESHOLD: isize = 32;
        let mut l = left;
        let mut r = right;
        unsafe {
            let mut tmp: TInner = mem::uninitialized();
            let pivot = get_pivot(ptr, left, right);
            while l <= r {
                while r > left && pivot < *(ptr.offset(r)) {
                    r -= 1;
                }
                while l < right && *(ptr.offset(l)) < pivot {
                    l += 1;
                }
                if l <= r {
                    ptr::copy_nonoverlapping(ptr.offset(l), &mut tmp, 1);
                    ptr::copy_nonoverlapping(ptr.offset(r), ptr.offset(l), 1);
                    ptr::copy_nonoverlapping(&mut tmp, ptr.offset(r), 1);
                    if r > 0 {
                        r -= 1;
                    }
                    l += 1;
                }
            }
            if r > INSERT_THRESHOLD + left {
                q_sort(ptr, left, r);
            }
            if right > INSERT_THRESHOLD + l {
                q_sort(ptr, l, right);
            }
            mem::forget(pivot);
            mem::forget(tmp);
        }
    }

    fn insert_sort<TInner: PartialOrd>(source: &mut [TInner]) {
        let ptr = source.as_mut_ptr();
        for i in 1..source.len() {
            let mut j = i;
            unsafe {
                let mut t: TInner = mem::uninitialized();
                ptr::copy_nonoverlapping(ptr.offset(j as isize), &mut t, 1);
                while 0 < j && *(ptr.offset((j - 1) as isize)) > t {
                    ptr::copy_nonoverlapping(
                        ptr.offset((j - 1) as isize),
                        ptr.offset(j as isize),
                        1,
                    );
                    j -= 1;
                }
                ptr::copy_nonoverlapping(&t, ptr.offset(j as isize), 1);
                mem::forget(t);
            }
        }
    }

    let size = source.len() - 1;
    q_sort(source.as_mut_ptr(), 0, size as isize);
    insert_sort(source);
}

イントロソート

クイックソートのもう一つの攻撃回避方法が、ソートの深さが一定を超過した場合に、ヒープソートに切り替える方法である。
ヒープソートは最悪計算量が$O(n\log n)$であることから、計算量の増加が抑えられる。
コードは次の通り。

IntroSort
fn sort<T: PartialOrd>(source: &mut [T]) {
    fn q_sort<TInner: PartialOrd>(ptr: *mut TInner, left: isize, right: isize, depth: usize) {
        fn heap_sort<THeap: PartialOrd>(ptr: *mut THeap, left: isize, right: isize) {
            let len = right - left + 1;
            if len > 1 {
                let ls = len >> 1;
                let mut i = ls;
                unsafe {
                    let ptr = ptr.offset(left);
                    while i > 0 {
                        i -= 1;
                        let mut n = i;
                        let mut j: THeap = mem::uninitialized();
                        ptr::copy_nonoverlapping(ptr.offset(i), &mut j, 1);
                        while n < ls {
                            let mut l1 = (n << 1) + 1;
                            if l1 + 1 < len && *(ptr.offset(l1)) < *(ptr.offset(l1 + 1)) {
                                l1 += 1;
                            }
                            if j >= *(ptr.offset(l1)) {
                                break;
                            }
                            ptr::copy_nonoverlapping(ptr.offset(l1), ptr.offset(n), 1);
                            n = l1;
                        }
                        ptr::copy_nonoverlapping(&j, ptr.offset(n), 1);
                        mem::forget(j);
                    }
                    let mut i = len;
                    while i > 0 {
                        i -= 1;
                        let mut j: THeap = mem::uninitialized();
                        ptr::copy_nonoverlapping(ptr.offset(i), &mut j, 1);
                        ptr::copy_nonoverlapping(ptr, ptr.offset(i), 1);

                        let mut n = 0;
                        let mut leaf = 1;
                        while leaf < i {
                            if leaf + 1 < i && *(ptr.offset(leaf)) < *(ptr.offset(leaf + 1)) {
                                leaf += 1;
                            }
                            if j >= *(ptr.offset(leaf)) {
                                break;
                            }
                            ptr::copy_nonoverlapping(ptr.offset(leaf), ptr.offset(n), 1);
                            n = leaf;
                            leaf = (n << 1) + 1;
                        }
                        ptr::copy_nonoverlapping(&j, ptr.offset(n), 1);
                        mem::forget(j);
                    }
                }
            }
        }

        if depth <= 0 {
            heap_sort(ptr, left, right);
        }

        const INSERT_THRESHOLD: isize = 32;
        let mut l = left;
        let mut r = right;
        unsafe {
            let mut pivot: TInner = mem::uninitialized();
            ptr::copy_nonoverlapping(ptr.offset(((left + right) >> 1)), &mut pivot, 1);
            let mut tmp: TInner = mem::uninitialized();
            while l <= r {
                while pivot < *(ptr.offset(r)) && r > left {
                    r -= 1;
                }
                while *(ptr.offset(l)) < pivot && l < right {
                    l += 1;
                }
                if l <= r {
                    ptr::copy_nonoverlapping(ptr.offset(l), &mut tmp, 1);
                    ptr::copy_nonoverlapping(ptr.offset(r), ptr.offset(l), 1);
                    ptr::copy_nonoverlapping(&mut tmp, ptr.offset(r), 1);
                    if r > 0 {
                        r -= 1;
                    }
                    l += 1;
                }
            }
            if r > INSERT_THRESHOLD + left {
                q_sort(ptr, left, r, depth - 1);
            }
            if right > INSERT_THRESHOLD + l {
                q_sort(ptr, l, right, depth - 1);
            }
            mem::forget(pivot);
            mem::forget(tmp);
        }
    }

    fn insert_sort<TInner: PartialOrd>(source: &mut [TInner]) {
        let ptr = source.as_mut_ptr();
        for i in 1..source.len() {
            let mut j = i;
            unsafe {
                let mut t: TInner = mem::uninitialized();
                ptr::copy_nonoverlapping(ptr.offset(j as isize), &mut t, 1);
                while 0 < j && *(ptr.offset((j - 1) as isize)) > t {
                    ptr::copy_nonoverlapping(
                        ptr.offset((j - 1) as isize),
                        ptr.offset(j as isize),
                        1,
                    );
                    j -= 1;
                }
                ptr::copy_nonoverlapping(&t, ptr.offset(j as isize), 1);
                mem::forget(t);
            }
        }
    }

    let size = source.len() - 1;
    if size > 32 {
        let depth = (f32::log2(source.len() as f32) as usize) << 1;
        q_sort(source.as_mut_ptr(), 0, size as isize, depth);
    }
    insert_sort(source);
}

このイントロソート、ランダムデータに対して処理を行ったはずだが、通常のクイックソートより速い結果が得られた。