Rust
競技プログラミング

Rustで競技プログラミング スターターキット

はじめに

自分の経験から、Rustで競技プログラミングをやるときに役に立ちそうなことをまとめました

Rustで競技プログラミングを始めた方がRust特有の引っかかりどころに引っかかることなく、
より問題の本質に集中できるようになれば幸いです:muscle:

対象

プログラミング言語Rustをだいたい理解したけど競技プログラミングでRustを使うのは不安な人

環境構築

とりあえず

  • RUST_BACKTRACE=1

を環境変数に設定しておく

エディタはなんでも良いですが

  • 入力補完 (Racer + RLS)
  • 保存時か任意のタイミングで自動整形 (Rustfmt)
  • 自動コンパイル & エラー箇所の表示

を導入しておくのをおすすめします。導入方法はググれば大丈夫だと思います

環境構築が面倒な方は私が作ったエディタがありますので是非試してみてください:grinning:
面倒な設定なしに上記の機能が全部使えます
hatoo/Accepted

rustup override

rustup overrideを使うことでカレントディレクトリ以下で使うRustのバージョンを指定できます。
詳細はrustup override helpで確認してください。

$ rustup install 1.15.1
# カレントディレクトリ以下でRust 1.15.1を使う
$ rustup override set 1.15.1

Rust 1.15.1にはRustFmtとRLSがcomponentにないことに注意してください。
rustup runを使えば、rustup run nightly rustfmtなどで現在の設定がどうであれnightlyのrustfmtを呼び出せるのでエディタの設定などに活用しましょう。

サイトごとの注意

Atcoder

  • Rustのバージョンが古い(執筆時現在1.15.1)
    • いくつかの新しい機能が使えない
  • 絶対に実行時間が2ms以上かかる
    • 最速コードを狙う場合に注意

Codeforces

  • Rustのバージョンは比較的新しい
  • 32bit環境で動いている
    • usizeが32bitの大きさなので要注意

AOJ

  • Rustのバージョンが古い(執筆時現在1.17.0)

yukicoder

  • とくになし

入出力

入力

scanf()のようなものはありません。
read_to_stringread_lineで標準入力からStringを受け取り、str::split_whitespaceで分解してstr::parseでパースしていきます。

ABC037 B - 編集
https://beta.atcoder.jp/contests/abc037/tasks/abc037_b
の入力を行うコード
$N$ $Q$
$L_1$ $R_1$ $T_1$
$:$
$L_Q$ $R_Q$ $T_Q$

一度に全部読み込むパターン

エラー処理はてきとうです

use std::io::Read;

fn main() {
    let mut buf = String::new();

    // 標準入力から全部bufに読み込む
    std::io::stdin().read_to_string(&mut buf).unwrap();

    // 読み込んだStringを空白で分解する
    let mut iter = buf.split_whitespace();

    let n: usize = iter.next().unwrap().parse().unwrap();
    let q: usize = iter.next().unwrap().parse().unwrap();

    let lrt: Vec<(usize, usize, u64)> = (0..q)
        .map(|_| {
            (
                iter.next().unwrap().parse().unwrap(),
                iter.next().unwrap().parse().unwrap(),
                iter.next().unwrap().parse().unwrap(),
            )
        })
        .collect();
}

マクロを使ってみる

毎回上のようなコードを書くのはしんどいので、自分はこんな感じのマクロで入力を処理しています。
こちらは一行ごとに入力を処理しています。

macro_rules! get {
      ($t:ty) => {
          {
              let mut line: String = String::new();
              std::io::stdin().read_line(&mut line).unwrap();
              line.trim().parse::<$t>().unwrap()
          }
      };
      ($($t:ty),*) => {
          {
              let mut line: String = String::new();
              std::io::stdin().read_line(&mut line).unwrap();
              let mut iter = line.split_whitespace();
              (
                  $(iter.next().unwrap().parse::<$t>().unwrap(),)*
              )
          }
      };
      ($t:ty; $n:expr) => {
          (0..$n).map(|_|
              get!($t)
          ).collect::<Vec<_>>()
      };
      ($($t:ty),*; $n:expr) => {
          (0..$n).map(|_|
              get!($($t),*)
          ).collect::<Vec<_>>()
      };
      ($t:ty ;;) => {
          {
              let mut line: String = String::new();
              std::io::stdin().read_line(&mut line).unwrap();
              line.split_whitespace()
                  .map(|t| t.parse::<$t>().unwrap())
                  .collect::<Vec<_>>()
          }
      };
      ($t:ty ;; $n:expr) => {
          (0..$n).map(|_| get!($t ;;)).collect::<Vec<_>>()
      };
}

fn main() {
    // 一行読み込み、空白で分解してusize2つパースする
    let (n, q) = get!(usize, usize);
    // usize usize u64 をq行読む Vec<(usize, usize,u64)> 
    let lrt = get!(usize, usize, u64; q);
}

マクロを使った例は上のパターンに比べてかなり遅いですが
$10^5$程度までの入力であればそれが原因でTLE (Time Limit Exceeded)になることはまずないのではないかと思います。

ベンチマーク

最後にそれぞれのやり方で入力を処理したときの実行時間を計測したので載せておきます。環境はWSL上のDebianです。

計測方法

yes 1 | head -n [行数] | time ./[実行ファイル]

ベンチマークに使ったコード

一度に全部読み込むコード

use std::io::Read;

fn main() {
    let mut buf = String::new();
    std::io::stdin().read_to_string(&mut buf).unwrap();

    // 条件を同じにするため一度Vecに貯める
    let v = buf.split_whitespace()
        .map(|x| x.parse::<u32>().unwrap())
        .collect::<Vec<_>>();

    // 合計を出力
    println!("{}", v.into_iter().sum::<u32>());
}

マクロを使ったコード

// マクロの定義は省略
fn main() {
    println!("{}", get!(u32; 10000000).into_iter().sum::<u32>());
}

結果

一回計測しただけなのでおおまかな参考程度に

入力行数 一度に全部読み込む(秒) 上のマクロを使う(秒)
$10^5$ 0.01 0.02
$10^6$ 0.04 0.11
$10^7$ 0.29 1.08

参考: Rust の標準入出力は(何も考えないで使うと)遅い - Qiita

出力

基本的にprintln!()で良いですが出力の行数が多い場合は BufWriter を使いましょう。使い方はベンチマークを参照。
浮動小数点を出力する場合も桁数を気にせずにprintln!("{}", x)のように書いて大丈夫です。

ベンチマーク

今回はオンライン上で実行しました。
10^5行"yes"と出力したときの実行時間を計測しました。

println!

fn main() {
    for _ in 0..100_000 {
        println!("yes");
    }
}

BufWriter

use std::io::{stdout, BufWriter, Write};

fn main() {
    let out = stdout();
    let mut out = BufWriter::new(out.lock());
    for _ in 0..100_000 {
        writeln!(out, "yes").unwrap();
    }
}

結果

一回計測しただけなのでおおまかな参考程度に

Atcoder(ms) Codeforces(ms) Yukicoder(ms)
println! 160 421 164
BufWriter 4 15 4

理由はわからないですがCodeforcesだと特に差が大きいですね。
Codeforcesは最後にシステムテストがあるので、出力行数が多いときはとりあえずBufWriterを使うのが良いと思います。

参考: Rustで高速な標準出力 | κeenのHappy Hacκing Blog

おまけ

fn with_bufwriter<F: FnOnce(BufWriter<StdoutLock>) -> ()>(f: F) {
        let out = stdout();
        let writer = BufWriter::new(out.lock());
        f(writer)
}

みたいなのをスニペットに登録しておけば便利かもしれません。

デバッグ出力

追記: 2019/01/18
Rust 1.32でdbgマクロがstableに入りましたのでそちらを使うのもよいです

printデバッグをする場合、以下のようなスニペットがあると便利かもしれません

macro_rules! debug {
      ($($a:expr),*) => {
          eprintln!(concat!($(stringify!($a), " = {:?}, "),*), $($a),*);
      }
  }
fn main() {
    let x = vec![1, 2, 3];
    let y = vec![4, 5, 6];
    // x = [1, 2, 3], y = [4, 5, 6], と出力する
    debug!(x, y);
}

Entry API

HashMapBTreeMapにはEntry APIというのが実装されていています。
連想配列の操作をするときに重要なので、聞いたことがない方は軽く予習しておくと良いと思います。

逆順で〇〇するには??

逆順でソートしたかったり、小さい順に値が出てくるBinaryHeapが欲しいときがあります。
そういうときはstd::cmp::Reverseで比較を逆転させます。

use std::cmp::Reverse;
use std::collections::BinaryHeap;

fn main() {
    let mut v = vec![5, 2, 1, 4, 3];

    // 通常は昇順でソートされるが、Reverseで比較が逆になるので降順にソートされる。
    v.sort_by_key(|&x| Reverse(x));
    assert_eq!(&v, &[5, 4, 3, 2, 1]);

    // BinaryHeapは値が大きい順で出てくるmax-heapだがReverseを使ってmin-heapにできる。
    let mut min_heap: BinaryHeap<Reverse<usize>> = v.into_iter().map(Reverse).collect();

    assert_eq!(min_heap.pop(), Some(Reverse(1)));
    assert_eq!(min_heap.pop(), Some(Reverse(2)));
    assert_eq!(min_heap.pop(), Some(Reverse(3)));
    assert_eq!(min_heap.pop(), Some(Reverse(4)));
    assert_eq!(min_heap.pop(), Some(Reverse(5)));
}

しかし、std::cmp::ReverseはRust1.19からの機能なのでAtcoder等の古いRustでは使えません
自分はほぼ同じ機能のstructをスニペットに入れています。

use std::cmp::Ordering;

#[derive(Eq, PartialEq, Clone, Debug)]
pub struct Rev<T>(pub T);

impl<T: PartialOrd> PartialOrd for Rev<T> {
    fn partial_cmp(&self, other: &Rev<T>) -> Option<Ordering> {
        other.0.partial_cmp(&self.0)
    }
}

impl<T: Ord> Ord for Rev<T> {
    fn cmp(&self, other: &Rev<T>) -> Ordering {
        other.0.cmp(&self.0)
    }
}

浮動小数点の比較

例えば
以下のコードのようにf64の配列をソートしようとするとコンパイルが出来ません。

let mut v: Vec<f64> = vec![0.1, 1.0, 2.0];
// コンパイルできない!!
// error[E0277]: the trait bound `f64: std::cmp::Ord` is not satisfied
v.sort();

同じ理由でBinaryHeapも作ることが出来ません。

use std::collections::BinaryHeap;
// これもだめ
let heap: BinaryHeap<f64> = BinaryHeap::new();

エラーメッセージの通り、f64, f32はOrdトレイトを実装していないためエラーになります。

なぜ浮動小数点型にOrdトレイトを実装できないのか?

浮動小数点にはNANというのがありますが、NANとの比較はすべてfalseになります

assert_eq!(std::f64::NAN > 0.0, false);
assert_eq!(std::f64::NAN == 0.0, false);
assert_eq!(std::f64::NAN < 0.0, false);

ところで、Ordトレイトはcmpメソッドを実装することを要求しています。

pub trait Ord: Eq + PartialOrd<Self> {
    fn cmp(&self, other: &Self) -> Ordering;
    // .. 省略
}
pub enum Ordering {
    Less,
    Equal,
    Greater,
}

ここで例えばOrd::cmp(&std::f64::NAN, &0.0)がなにを返せばよいのか考えると

Ordering::Lessを返す => (NAN < 0.0) == falseなので矛盾する
Ordering::Equalを返す => (NAN == 0.0) == falseなので矛盾する
Ordering::Greater => (NAN > 0.0) == falseなので矛盾する

となって何を返しても矛盾してしまいます。よって浮動小数点にOrdトレイトを実装することは出来ません。

じゃあどうするの?

浮動小数点型もPartialOrdトレイトは実装しているので、PartialOrdトレイトを実装している型に無理やりOrdトレイトを実装するラッパstructを作ります。NANを比較しようとした場合はパニックを起こします。

// Partial orderなものをTotal orderにする
#[derive(PartialEq, PartialOrd)]
pub struct Total<T>(pub T);

impl<T: PartialEq> Eq for Total<T> {}

impl<T: PartialOrd> Ord for Total<T> {
    fn cmp(&self, other: &Total<T>) -> std::cmp::Ordering {
        self.0.partial_cmp(&other.0).unwrap()
    }
}

つかいかた

let mut v: Vec<f64> = vec![0.1, 1.0, 2.0];
// ソートできる!
v.sort_by_key(|&x| Total(x));

// f64のBinaryHeapもできる!
let heap: BinaryHeap<Total<f64>> = BinaryHeap::new();

整数のオーバーフロー

RustではDebugビルド時のみオーバーフロー、負のオーバーフローを実行時に検出します。
オンライン上で実行されるときはReleaseビルドで実行されるのでオーバーフローは検出されないので注意しましょう。

デフォルトの整数型

スタック領域を増やす

長めの再帰関数を呼び出すとスタックオーバーフローを起こすことがあります。(Codeforcesで一回経験しました)
そういうときはstd::thread::Builderを使ってスタックサイズを指定しつつ別のスレッドを作ることで解決できます。

fn main() {
    std::thread::Builder::new()
        .name("big stack size".into())
        .stack_size(32 * 1024 * 1024) // 32 MBのスタックサイズ
        .spawn(|| {
            // ここで長い再帰を実行
        })
        .unwrap()
        .join()
        .unwrap();
}

参考1: https://stackoverflow.com/a/44042122
参考2: rust - How to set the thread stack size during compile time? - Stack Overflow

二分探索について

slice::binary_searchというメソッドがありますが、

containing the index where a matching element could be inserted while maintaining sorted order

とあるように、そのインデックスに挿入したときにソート済みの状態になるように保証されているだけなので、
C++のlower_boundやupper_boundとは全然違います

例えば、同じ値が並んでいる場合どのインデックスが帰ってくるかはバージョンによっても違います

fn main() {
    let v = vec![4,4,4];
    // Rust 1.21.0         -> Ok(1)
    // Rust 1.26.0-nightly -> Ok(2)
    println!("{:?}", v.binary_search(&4) );
}

しょうがないので、lower_bound, upper_bound相当がほしいときは自分で実装しましょう
自分が実装した例はこちら
https://github.com/hatoo/competitive-rust-snippets/blob/master/src/binary_search.rs
#[snippet = ".."]みたいな行は外部ツール用なので無視してください

next_permutation

ありません。
自分はbluss/permutohedronからコピペしています。

bitset

ありません。
自分が実装した例はこちら
https://github.com/hatoo/competitive-rust-snippets/blob/master/src/bitset.rs

なんかうまいスニペットの管理方法ある?

こちらに記事を書きましたのでよければ参考にしてください:bow:
Rustで競技プログラミングをするときの"スニペット管理"をまじめに考える(cargo-snippetの紹介)