はじめに
Rustでマラソン形式の競技プログラミングコンテストに出場したときに何度か使ったコードを紹介します。
Rustで出場できるマラソンはあまりない印象ですが、参考になれば幸いです
また、自分の記事ですがalgoの方に興味がある方は
Rustで競技プログラミング スターターキット
もしRustでのスニペット管理に興味ある方は
Rustで競技プログラミングをするときの"スニペット管理"をまじめに考える(cargo-snippetの紹介)
を御覧ください。
この記事で紹介しているスニペットはすべて、hatoo/competitive-rust-snippetsで管理しています。
PCG
2023/08/07
PCGのほうが好きになったのでPCGのセクションを足しました。
XorShiftのところは以前のまま置いておきます。
外部クレートは使えないので簡単にPCGで乱数生成器を用意します。
今回はPCGファミリのpcg32siを使います。周期32bit/内部状態32bitです。
#[repr(transparent)]
pub struct PCG32si {
state: u32,
}
impl PCG32si {
const PCG_DEFAULT_MULTIPLIER_32: u32 = 747796405;
const PCG_DEFAULT_INCREMENT_32: u32 = 2891336453;
fn pcg_oneseq_32_step_r(&mut self) {
self.state = self
.state
.wrapping_mul(Self::PCG_DEFAULT_MULTIPLIER_32)
.wrapping_add(Self::PCG_DEFAULT_INCREMENT_32);
}
fn pcg_output_rxs_m_xs_32_32(state: u32) -> u32 {
let word = ((state >> ((state >> 28).wrapping_add(4))) ^ state).wrapping_mul(277803737);
(word >> 22) ^ word
}
pub fn new(seed: u32) -> Self {
let mut rng = Self { state: seed };
rng.pcg_oneseq_32_step_r();
rng.state = rng.state.wrapping_add(seed);
rng.pcg_oneseq_32_step_r();
rng
}
pub fn next_u32(&mut self) -> u32 {
let old_state = self.state;
self.pcg_oneseq_32_step_r();
Self::pcg_output_rxs_m_xs_32_32(old_state)
}
pub fn next_f32(&mut self) -> f32 {
const FLOAT_SIZE: u32 = core::mem::size_of::<f32>() as u32 * 8;
const PRECISION: u32 = 23 + 1;
const SCALE: f32 = 1.0 / (1 << PRECISION) as f32;
const SHIFT: u32 = FLOAT_SIZE - PRECISION;
let value = self.next_u32();
let value = value >> SHIFT;
SCALE * value as f32
}
pub fn next_f32_range(&mut self, min: f32, max: f32) -> f32 {
min + (max - min) * self.next_f32()
}
}
XorShift
外部クレートは使えないので簡単にXorShiftで乱数生成器を用意します。
周期は64bitで十分だと思います(適当)。
#[derive(Debug)]
#[allow(dead_code)]
pub struct Xorshift {
seed: u64,
}
impl Xorshift {
#[allow(dead_code)]
pub fn new() -> Xorshift {
Xorshift {
seed: 0xf0fb588ca2196dac,
}
}
#[allow(dead_code)]
pub fn with_seed(seed: u64) -> Xorshift {
Xorshift { seed: seed }
}
#[inline]
#[allow(dead_code)]
pub fn next(&mut self) -> u64 {
self.seed = self.seed ^ (self.seed << 13);
self.seed = self.seed ^ (self.seed >> 7);
self.seed = self.seed ^ (self.seed << 17);
self.seed
}
#[inline]
#[allow(dead_code)]
pub fn rand(&mut self, m: u64) -> u64 {
self.next() % m
}
#[inline]
#[allow(dead_code)]
// 0.0 ~ 1.0
pub fn randf(&mut self) -> f64 {
use std::mem;
const UPPER_MASK: u64 = 0x3FF0000000000000;
const LOWER_MASK: u64 = 0xFFFFFFFFFFFFF;
let tmp = UPPER_MASK | (self.next() & LOWER_MASK);
let result: f64 = unsafe { mem::transmute(tmp) };
result - 1.0
}
}
速いハッシュ関数
2019/05/30
計測したところ、以前載せていたFNVよりrustc-hashのほうが速かったのでrustc-hashをおすすめするように修正しました
RustのHashSet
等ではデフォルトでSipHashが使われますが、セキュアな分遅いです。
rustc-hashを使うと速くなります。
短いのでrust-lang/rustc-hashからコピペすれば良いと思います。
ベンチマーク
0から順番に10^5
個の数字をinsertする時間を計測しました。
時間(ns) | |
---|---|
HashSet | 4,328,125 |
rustc-hash | 1,014,883 |
rustc-hashの方が速いですね
Interval Heap
最大値と最小値をすばやく取り出すことができるヒープです。
これを使って上位N個しか保持しない個数制限付きのヒープ(LimitedIntervalHeap)を作って、ビームサーチの情報を貯めるために使います。
ビームサーチでは評価値の上からN番目の状態までしか必要ないので、上位N個の状態しか保持しないヒープを使うことによって計算量を落とすことが出来ます。
#[derive(Clone, Debug)]
struct IntervalHeap<T: Ord + Eq> {
data: Vec<T>,
}
impl<T: Ord + Eq> IntervalHeap<T> {
#[allow(dead_code)]
fn new() -> IntervalHeap<T> {
IntervalHeap { data: Vec::new() }
}
#[allow(dead_code)]
fn with_capacity(n: usize) -> IntervalHeap<T> {
IntervalHeap {
data: Vec::with_capacity(n),
}
}
#[allow(dead_code)]
#[inline]
fn len(&self) -> usize {
self.data.len()
}
#[allow(dead_code)]
#[inline]
fn is_empty(&self) -> bool {
self.data.is_empty()
}
#[allow(dead_code)]
#[inline]
fn push(&mut self, x: T) {
let i = self.data.len();
self.data.push(x);
self.up(i);
}
#[allow(dead_code)]
#[inline]
fn peek_min(&self) -> Option<&T> {
self.data.first()
}
#[allow(dead_code)]
#[inline]
fn peek_max(&self) -> Option<&T> {
if self.data.len() > 1 {
self.data.get(1)
} else {
self.data.first()
}
}
#[allow(dead_code)]
#[inline]
fn pop_min(&mut self) -> Option<T> {
if self.data.len() == 1 {
return self.data.pop();
}
if self.data.is_empty() {
return None;
}
let len = self.data.len();
self.data.swap(0, len - 1);
let res = self.data.pop();
self.down(0);
res
}
#[allow(dead_code)]
#[inline]
fn pop_max(&mut self) -> Option<T> {
if self.data.len() <= 2 {
return self.data.pop();
}
if self.data.is_empty() {
return None;
}
let len = self.data.len();
self.data.swap(1, len - 1);
let res = self.data.pop();
self.down(1);
res
}
#[allow(dead_code)]
#[inline]
fn parent(i: usize) -> usize {
((i >> 1) - 1) & !1
}
#[allow(dead_code)]
#[inline]
fn down(&mut self, i: usize) {
let mut i = i;
let n = self.data.len();
if i & 1 == 0 {
while (i << 1) + 2 < n {
let mut k = (i << 1) + 2;
if k + 2 < n
&& unsafe { self.data.get_unchecked(k + 2) }
< unsafe { self.data.get_unchecked(k) }
{
k = k + 2;
}
if unsafe { self.data.get_unchecked(i) } > unsafe { self.data.get_unchecked(k) } {
self.data.swap(i, k);
i = k;
if i + 1 < self.data.len()
&& unsafe { self.data.get_unchecked(i) }
> unsafe { self.data.get_unchecked(i + 1) }
{
self.data.swap(i, i + 1);
}
} else {
break;
}
}
} else {
while (i << 1) + 1 < n {
let mut k = (i << 1) + 1;
if k + 2 < n
&& unsafe { self.data.get_unchecked(k + 2) }
> unsafe { self.data.get_unchecked(k) }
{
k = k + 2;
}
if unsafe { self.data.get_unchecked(i) } < unsafe { self.data.get_unchecked(k) } {
self.data.swap(i, k);
i = k;
if i > 0
&& unsafe { self.data.get_unchecked(i) }
< unsafe { self.data.get_unchecked(i - 1) }
{
self.data.swap(i, i - 1);
}
} else {
break;
}
}
}
}
#[allow(dead_code)]
#[inline]
fn up(&mut self, i: usize) {
let mut i = i;
if i & 1 == 1
&& unsafe { self.data.get_unchecked(i) } < unsafe { self.data.get_unchecked(i - 1) }
{
self.data.swap(i, i - 1);
i -= 1;
}
while i > 1
&& unsafe { self.data.get_unchecked(i) }
< unsafe { self.data.get_unchecked(Self::parent(i)) }
{
let p = Self::parent(i);
self.data.swap(i, p);
i = p;
}
while i > 1
&& unsafe { self.data.get_unchecked(i) }
> unsafe { self.data.get_unchecked(Self::parent(i) + 1) }
{
let p = Self::parent(i) + 1;
self.data.swap(i, p);
i = p;
}
}
#[allow(dead_code)]
#[inline]
fn clear(&mut self) {
self.data.clear();
}
}
#[derive(Clone, Debug)]
struct LimitedIntervalHeap<T: Ord + Eq> {
heap: IntervalHeap<T>,
limit: usize,
}
impl<T: Ord + Eq> LimitedIntervalHeap<T> {
#[allow(dead_code)]
fn new(limit: usize) -> LimitedIntervalHeap<T> {
LimitedIntervalHeap {
heap: IntervalHeap::with_capacity(limit),
limit: limit,
}
}
#[allow(dead_code)]
#[inline]
fn is_empty(&self) -> bool {
self.heap.is_empty()
}
#[allow(dead_code)]
#[inline]
fn push(&mut self, x: T) -> Option<T> {
if self.heap.len() < self.limit {
self.heap.push(x);
None
} else {
if self.heap.data[0] < x {
let mut x = x;
std::mem::swap(&mut x, &mut self.heap.data[0]);
if self.heap.len() >= 2 && self.heap.data[0] > self.heap.data[1] {
self.heap.data.swap(0, 1);
}
self.heap.down(0);
Some(x)
} else {
Some(x)
}
}
}
#[allow(dead_code)]
#[inline]
fn pop(&mut self) -> Option<T> {
self.heap.pop_max()
}
#[allow(dead_code)]
#[inline]
fn clear(&mut self) {
self.heap.clear();
}
}
ベンチマーク
乱数を10^5
個挿入したときの時間を計測しました。
コードは https://github.com/hatoo/competitive-rust-snippets/blob/master/src/interval_heap.rs の下の方のベンチマークコードを参照してください。
時間(ns) | |
---|---|
BinaryHeap | 1,851,530 |
IntervalHeap | 2,600,608 |
LimitedIntervalHeap(50000個まで保持) | 467,430 |
BinaryHeapのかわりにIntervalHeapを使った場合は遅くなりますが、上位の値しか必要ない場合はLimitedIntervalHeapの方が早いですね。
Rcを使ったリスト
単純なリスト構造ですがRcを使って複製、値の追加をO(1)
で出来るようにしています。
ビームサーチなどの探索をする際、状態に今までの行動履歴等の、探索に従って大きくなるようなデータを保存するために使っています。
use std::rc::Rc;
#[derive(Debug)]
struct RcListInner<T> {
parent: RcList<T>,
value: T,
}
#[doc = "O(1) clone, O(1) push"]
#[derive(Clone, Debug)]
struct RcList<T>(Option<Rc<RcListInner<T>>>);
impl<T: Clone> RcList<T> {
#[allow(dead_code)]
fn new() -> Self {
RcList(None)
}
#[allow(dead_code)]
#[inline]
fn push(&mut self, value: T) {
*self = RcList(Some(Rc::new(RcListInner {
parent: self.clone(),
value,
})));
}
#[allow(dead_code)]
fn to_vec(&self) -> Vec<T> {
if let Some(ref inner) = self.0 {
let mut p = inner.parent.to_vec();
p.push(inner.value.clone());
p
} else {
Vec::new()
}
}
}