読者です 読者をやめる 読者になる 読者になる

Rustで挿入ソート + 強制move outで高速化

挿入ソートは時間計算量  {O(n^{2})} のソートアルゴリズムであるが、特に入力  A の転倒数に対して  {O(\mathrm{inv}(A))} で抑えられること、また定数倍で高速なことから特定の場面で使われる場合がある。

Rustで挿入ソートを素朴に実装すると以下のようになる。

pub fn safe_insert_sort<T: Ord>(arr: &mut [T]) {
    for i in 1..arr.len() {
        let mut j = i;
        while 0 < j && arr[j-1] > arr[j] {
            arr.swap(j-1, j);
            j -= 1;
        }
    }
}

しかし、 arr[i] をswapによっていちいち配列に書き戻していたり、配列の境界チェックをしていたりと、このコードには無駄がある。そこで unsafe を用いてより高速化することを考える。ここで参考になるのが std::mem::swap実装である。

#[inline]
#[stable(feature = "rust1", since = "1.0.0")]
pub fn swap<T>(x: &mut T, y: &mut T) {
    unsafe {
        // Give ourselves some scratch space to work with
        let mut t: T = uninitialized();

        // Perform the swap, `&mut` pointers never alias
        ptr::copy_nonoverlapping(&*x, &mut t, 1);
        ptr::copy_nonoverlapping(&*y, x, 1);
        ptr::copy_nonoverlapping(&t, y, 1);

        // y and t now point to the same thing, but we need to completely
        // forget `t` because we do not want to run the destructor for `T`
        // on its value, which is still owned somewhere outside this function.
        forget(t);
    }
}

Rustでは通常、borrowされているメモリ領域を未初期化状態にする処理はできない。そのような処理は必ずしもUB(未定義動作)ではないが、自分で安全性を確認した上で unsafe をつけて書く必要がある。

上記の swapソースコードには、これらを考慮した上で問題のないコードを書くのに必要な道具が揃っている。すなわち、

である。

これを使って、より効率的な挿入ソートを書いてみたものが以下である。

pub fn unsafe_insert_sort<T: Ord>(arr: &mut [T]) {
    let len = arr.len();
    let ptr = arr.as_mut_ptr();
    for i in 1..len {
        unsafe {
            let mut j = i;
            let mut t: T = std::mem::uninitialized();
            std::ptr::copy_nonoverlapping(ptr.offset(j as isize), &mut t, 1);
            while 0 < j && *(ptr.offset((j-1) as isize)) > t {
                std::ptr::copy_nonoverlapping(ptr.offset((j-1) as isize), ptr.offset(j as isize), 1);
                j -= 1;
            }
            std::ptr::copy_nonoverlapping(&t, ptr.offset(j as isize), 1);
            std::mem::forget(t);
        }
    }
}

これでだいたい3倍ほど速くなるようだ。(ベンチマーク結果は以下を参照)

コード全体とベンチマーク結果

#![cfg_attr(test, feature(test))]

#[cfg(test)]
extern crate test;
#[cfg(test)]
extern crate rand;

pub fn safe_insert_sort<T: Ord>(arr: &mut [T]) {
    for i in 1..arr.len() {
        let mut j = i;
        while 0 < j && arr[j-1] > arr[j] {
            arr.swap(j-1, j);
            j -= 1;
        }
    }
}
pub fn unsafe_insert_sort<T: Ord>(arr: &mut [T]) {
    let len = arr.len();
    let ptr = arr.as_mut_ptr();
    for i in 1..len {
        unsafe {
            let mut j = i;
            let mut t: T = std::mem::uninitialized();
            std::ptr::copy_nonoverlapping(ptr.offset(j as isize), &mut t, 1);
            while 0 < j && *(ptr.offset((j-1) as isize)) > t {
                std::ptr::copy_nonoverlapping(ptr.offset((j-1) as isize), ptr.offset(j as isize), 1);
                j -= 1;
            }
            std::ptr::copy_nonoverlapping(&t, ptr.offset(j as isize), 1);
            std::mem::forget(t);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::fmt::Debug;
    use test::Bencher;
    use rand::{XorShiftRng, Rng, SeedableRng};
    fn test_sort<T: Ord + Clone + Debug>(arr: &[T]) {
        let mut arr1 = arr.to_vec();
        let mut arr2 = arr.to_vec();
        let mut arr3 = arr.to_vec();
        arr1.sort();
        safe_insert_sort(&mut arr2);
        unsafe_insert_sort(&mut arr3);
        assert_eq!(arr1, arr2);
        assert_eq!(arr1, arr3);
    }
    #[test]
    fn test_sorts() {
        test_sort(&[1, 2, 3, 4]);
        test_sort(&[4, 2, 3, 1]);
        test_sort(&[3, 2, 3, 0]);
        test_sort(&[3, 3, 6, 2, 1, 5, 7, 3, 1, 2]);
    }
    #[bench]
    fn bench_safe_insert_sort_by_worst(b: &mut Bencher) {
        let v : Vec<u32> = (0..1000).rev().collect();
        b.iter(|| {
            safe_insert_sort(&mut v.clone());
        })
    }
    #[bench]
    fn bench_unsafe_insert_sort_by_worst(b: &mut Bencher) {
        let v : Vec<u32> = (0..1000).rev().collect();
        b.iter(|| {
            unsafe_insert_sort(&mut v.clone());
        })
    }
    #[bench]
    fn bench_safe_insert_sort_by_uniform_random(b: &mut Bencher) {
        let mut v : Vec<u32> = (0..1000).collect();
        XorShiftRng::from_seed([189522394, 1694417663, 1363148323, 4087496301]).shuffle(&mut v);
        b.iter(|| {
            safe_insert_sort(&mut v.clone());
        })
    }
    #[bench]
    fn bench_unsafe_insert_sort_by_uniform_random(b: &mut Bencher) {
        let mut v : Vec<u32> = (0..1000).collect();
        XorShiftRng::from_seed([189522394, 1694417663, 1363148323, 4087496301]).shuffle(&mut v);
        b.iter(|| {
            unsafe_insert_sort(&mut v.clone());
        })
    }
}
[package]
name = "insert-sort"
version = "0.1.0"
authors = ["Masaki Hara <ackie.h.gmai@gmail.com>"]

[dependencies]
rand = "0.3"
$ cargo bench
   Compiling insert-sort v0.1.0 
    Finished release [optimized] target(s) in 1.75 secs
     Running target/release/deps/insert_sort-0526ddc8d41f2829

running 5 tests
test tests::test_sorts ... ignored
test tests::bench_safe_insert_sort_by_uniform_random   ... bench:     527,505 ns/iter (+/- 47,327)
test tests::bench_safe_insert_sort_by_worst            ... bench:     985,459 ns/iter (+/- 65,682)
test tests::bench_unsafe_insert_sort_by_uniform_random ... bench:     176,659 ns/iter (+/- 25,409)
test tests::bench_unsafe_insert_sort_by_worst          ... bench:     326,565 ns/iter (+/- 77,140)

test result: ok. 0 passed; 0 failed; 1 ignored; 4 measured

ベンチマーク環境は Rust nightly rustc 1.18.0-nightly (5309a3e31 2017-04-03), Ubuntu 16.04.1 over VirtualBox over Windows 10, Surface Pro 2