rust
競技プログラミング

Rustの競プロ向け入力関数やマクロを高速化できるかやってみた

tl; dr

  • Rust 競プロer はそれぞれ独自に入力取得用のツールを自作しているのが現状
  • tubo28さんの記事「Rust の標準入出力は(何も考えないで使うと)遅い」を参考に、最近紹介されていた競プロ用の入力取得関数やマクロが高速化できるか試してみた
  • その結果、以下のように高速化できた
  • 高速化後の入力関数とマクロの速さは同程度(約9%の差)なので、どちらを使うかはお好みしだい
  • 追記: 高速化後の入力関数とマクロでは入力関数の方がわずかに速い
    • 入力関数版は BufRead::read_until() を使用し入力データを一度だけスキャンするが、入力マクロでは BufRead::read_line()str::split_whitespace() によって入力データを二度スキャンする。(なお、BufRead::read_line()BufRead::read_until() で実現されている)
    • 入力関数版では str::from_utf8() の代わりに unsafe { str::from_utf8_unchecked() } を使うことで、さらに高速化できる
    • マクロ版を書き換えて、BufRead::read_line()str::split_whitespace() の代わりに入力関数版を呼んでもいいかもしれない。

性能測定用データの作成

まず性能測定用の入力データを生成するプログラムを書きます。バイナリプロジェクトに元からある src/main.rs は削除して、src/bin ディレクトリを作ります。

$ cargo new --bin stdin-reader
$ cd stdin-reader
$ rm src/main.rs
$ mkdir src/bin

入力データは以下の形式で、データ件数は3000万件としました。

10000000     # 入力データの行数。1000万行
1 -1e3 1     # 一行に三つのデータ。型は u32, f64, String
2 -2e3 2
3 -3e3 3
...
N -Ne3 N

src/bin/data_gen.rs を作り、以下のプログラムを書きます。

src/bin/data_gen.rs
use std::io::{BufWriter, Write};

fn main() {
    // 生成する入力データの行数(10^7)
    let count = 10usize.pow(7);

    // 標準出力をロックし、BufWriterで包んでバッファリングする。
    let stdout = std::io::stdout();
    let mut out = BufWriter::new(stdout.lock());

    writeln!(out, "{}", count).expect("failed to write");
    for i in 0..count {
        writeln!(out, "{} -{}e3 {}", i, i, i).expect("failed to write");
    }

    // エラーを取りこぼさないよう明示的にflushする。
    // 参考: Rustといえどリソースの解放は注意
    //       http://keens.github.io/blog/2016/01/08/rusttoiedoriso_sunokaihouhachuui/
    out.flush().expect("failed to flush");
}

入力データを作成しましょう。

$ cargo run --release --bin data_gen > test.in

# 行数とワード数を数える
$ wc -lw test.in
 10000001 30000001 test.in

3000万 + 1個のデータを用意できました。

tubo28さんの入力関数を高速化してみる

AtCoder に登録したら解くべき精選過去問 10 問を Rust で解いてみた -- 入力関数 より

この記事では、一例として次の入力関数を使います。以降の記事中のソースコードは定義を省略します。

read 関数は標準入力から空白をスキップしてトークンを受け取り、T 型に変換して返します。エラーハンドリングを行わない競技プログラミング用の簡素なもので、読み込みと変換に失敗した場合は異常終了します。

また、問題になる場面はあまりないのですがこの入力関数は少々遅いので、速い入出力が必要になったときはおのおの解決方法を調べてほしいです

記事で紹介されている read() 関数は以下のようになります。

src/bin/main1a.rs
use std::io::*;
use std::str::FromStr;

pub fn read<T: FromStr>() -> T {
    let stdin = stdin();
    let stdin = stdin.lock();
    let token: String = stdin
        .bytes()
        .map(|c| c.expect("failed reading char") as char)
        .skip_while(|c| c.is_whitespace())
        .take_while(|c| !c.is_whitespace())
        .collect();
    token.parse().ok().expect("failed parsing")
}

元の関数を使う

read() 関数を使って性能測定用データを読み込むには以下のようにします。

src/bin/main1a.rs
fn main() {
    let count = read::<usize>();

    for i in 0..count {
        assert_eq!(i as u32, read::<u32>());
        assert_eq!(i as f64 * -1.0e3, read::<f64>());
        assert_eq!(i.to_string(), read::<String>());
    }
}

この main() 関数と先ほどの read() 関数を src/bin/main1a.rs に書いて、実行してみましょう。

$ cargo build --release --bin main1a
$ time ./target/release/main1a < test.in

高速化してみる

先ほどの read() をもう一度見て、高速化できそうな場所を探しましょう。

src/bin/main1a.rs
pub fn read<T: FromStr>() -> T {
    let stdin = stdin();
    let stdin = stdin.lock();
    let token: String = stdin
        .bytes()
        .map(|c| c.expect("failed reading char") as char)
        .skip_while(|c| c.is_whitespace())
        .take_while(|c| !c.is_whitespace())
        .collect();
    token.parse().ok().expect("failed parsing")
}

遅くなる要因として以下のものがあります。

  1. read() が呼ばれるごとに stdin.lock() を取得している
  2. 余分な変換がある。stdin からバイト列を取得 → char 型(4バイト)に変換 → 次の whitespace まで読み進む → String 型(UTF-8)に変換

1についてはプログラムの最初に stdin.lock() を1回実行するだけで済むようにしましょう。構造体を作り stdin.lock() で取得した StdinLock を格納します。

2についてですが、まず、Rustの一般的な書き方では BufRead トレイトの read_line() で一行を String として受け取り、str::split_whitespace() で分解する形になります。

しかし、str::split_whitespace() を使う場合、ライフタイムの関係で、一回の呼び出して一行分の値(複数の値)を取り出すような関数は書けますが、read() のような一回の呼び出して一つの値だけを取り出す関数は書けません。そのため今のような書き方になっていると考えられます。

この問題は hatoo@githubさんの入力マクロのように、一行分の値を取り出すマクロを定義すると回避できるのですが、ここでは、read() と同じく一回の呼び出しで一つの値だけを取り出す方法にこだわってみます。

char 型を経由せずに whitespace まで読む方法として、BufRead::read_until() が使えそうです。ただしこのメソッドは区切り文字として byte 型の値を一つしか取れません。そのため、split_whitespace() が行うような、スペース文字と改行文字のどちらか一方で区切るといったことはできません。

対策として、値の次の区切り文字がスペースの場合は reads()、改行の場合は readl() という風に二つのメソッドを定義しました。これらの関数は任意のバイトを区切り文字にできる read_until() メソッドを呼び出します。

先に main() 関数を見ましょう。上記の方針で改良した入力関数は以下のように使います。

src/bin/main1b.rs
fn main() {
    // stdinのロックを取得し、StdinReaderを作る。
    let stdin = std::io::stdin();
    let mut r = utils::StdinReader::new(stdin.lock());

    let count = r.readl::<usize>();  // 次の改行まで読む

    for i in 0..count {
        assert_eq!(i as u32, r.reads::<u32>());           // 次のスペースまで読む
        assert_eq!(i as f64 * -1.0e3, r.reads::<f64>());  // 次のスペースまで読む
        assert_eq!(i.to_string(), r.readl::<String>());   // 次の改行まで読む
    }
}

改良後のメソッドは以下のようになります。一応 utils というモジュール内に定義してみました。

src/bin/main1b.rs
mod utils {

    use std::io::BufRead;
    use std::str::{self, FromStr};

    pub struct StdinReader<R: BufRead> {
        reader: R,
        buf: Vec<u8>,
    }

    impl<R: BufRead> StdinReader<R> {
        pub fn new(reader: R) -> Self {
            Self {
                reader,
                buf: Vec::new(),
            }
        }

        // 区切り文字:スペース
        pub fn reads<T: FromStr>(&mut self) -> T {
            self.read_until(b' ')
        }

        // 区切り文字:改行
        pub fn readl<T: FromStr>(&mut self) -> T {
            self.read_until(b'\n')
        }

        pub fn read_until<T: FromStr>(&mut self, delim: u8) -> T {
            // self.bufに次のトークンをセットする
            loop {
                self.buf.clear();
                let len = self.reader
                    .read_until(delim, &mut self.buf)
                    .expect("failed reading bytes");
                match len {
                    0 => panic!("early eof"),
                    1 if self.buf[0] == delim => (), // 区切り文字だけなのでもう一度ループ
                    _ => {
                        // トークンが得られた
                        // 最後の文字が区切り文字なら削除
                        if self.buf[len - 1] == delim {
                            self.buf.truncate(len - 1);
                        }
                        break; // ループから脱出
                    }
                }
            }

            // 文字列をT型へパースする
            let elem = str::from_utf8(&self.buf).expect("invalid utf-8 string");
            elem.parse().unwrap_or_else(|_| panic!(format!("failed parsing: {}", elem)))
        }
    }

}

念のためテストケースも作成しました。utils モジュール内に tests モジュールを定義します。

src/bin/main1b.rs
mod utils {
    // StdinReaderの定義は先ほどと同じ。省略

    #[cfg(test)]
    mod tests {
        use super::StdinReader;
        use std::io;

        #[test]
        fn basics() {
            let cursor = io::Cursor::new(b"-123 456.7 Hello, world!");
            let mut reader = StdinReader::new(cursor);

            assert_eq!(-123i32, reader.reads());
            assert_eq!(456.7f64, reader.reads());
            assert_eq!("Hello, world!".to_string(), reader.readl::<String>());
        }

        #[test]
        fn edge_cases() {
            {
                let cursor = io::Cursor::new(b"8");
                let mut reader = StdinReader::new(cursor);
                assert_eq!(8u32, reader.readl());
            }
            {
                let cursor = io::Cursor::new(b"\n9");
                let mut reader = StdinReader::new(cursor);
                assert_eq!(9i32, reader.readl());
            }
            {
                let cursor = io::Cursor::new(b"\n\n10\n11");
                let mut reader = StdinReader::new(cursor);
                assert_eq!(10u8, reader.readl());
                assert_eq!(11u8, reader.readl());
            }
        }

        #[test]
        fn with_commas() {
            let cursor = io::Cursor::new(b"1,-2,3.0");
            let mut reader = StdinReader::new(cursor);
            assert_eq!(1u32, reader.read_until(b','));
            assert_eq!(-2i32, reader.read_until(b','));
            assert_eq!(3.0f64, reader.read_until(b','));
        }
    }

}

StdinReaderBufRead トレイトに対するジェネリクスとして定義したおかげで、テストの時は入力源として stdin の代わりに std::io::Cursor が使えました。

また、競プロでは必要なさそうですが、区切り文字としてカンマも使ったり、スペースを含んだ文字列 "Hello, world!" を扱ったりもできます。

テストを実行してみましょう。

$ cargo test --bin main1b
...
running 3 tests
test utils::tests::edge_cases ... ok
test utils::tests::basics ... ok
test utils::tests::with_commas ... ok

test result: ok. 3 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out

追記:さらなる高速化

記事を投稿後、もう少しいじったところ、この行、

src/bin/main1b.rs
            let elem = str::from_utf8(&self.buf).expect("invalid utf-8 string");

を以下のアンセーフな関数に変更したところ、さらに高速化しました。

src/bin/main1b.rs
            let elem = unsafe { str::from_utf8_unchecked(&self.buf) };

後者は入力が UTF-8 として正しいかのチェックを行いません。競プロでは不正な入力データを考慮する必要がないため後者で十分だと考えられます。

変更後の utils モジュールを以下に示します。なお、今回の main() 関数からの使用では効果がわかりませんでしたが、念のため各メソッドに #[inline] 属性を追加しました。

src/bin/main1b.rs
mod utils {

    use std::io::BufRead;
    use std::str::{self, FromStr};

    pub struct StdinReader<R: BufRead> {
        reader: R,
        buf: Vec<u8>,
    }

    impl<R: BufRead> StdinReader<R> {
        pub fn new(reader: R) -> Self {
            Self {
                reader,
                buf: Vec::new(),
            }
        }

        #[allow(unused)]
        #[inline]
        // 区切り文字:スペース
        pub fn reads<T: FromStr>(&mut self) -> T {
            self.read_until(b' ')
        }

        #[allow(unused)]
        #[inline]
        // 区切り文字:改行
        pub fn readl<T: FromStr>(&mut self) -> T {
            self.read_until(b'\n')
        }

        #[inline]
        pub fn read_until<T: FromStr>(&mut self, delim: u8) -> T {
            // self.bufに次のトークンをセットする
            loop {
                self.buf.clear();
                let len = self.reader
                    .read_until(delim, &mut self.buf)
                    .expect("failed reading bytes");
                match len {
                    0 => panic!("early eof"),
                    1 if self.buf[0] == delim => (), // 区切り文字だけなのでもう一度ループ
                    _ => {
                        // トークンが得られた
                        // 最後の文字が区切り文字なら削除
                        if self.buf[len - 1] == delim {
                            self.buf.truncate(len - 1);
                        }
                        break; // ループから脱出
                    }
                }
            }

            // 文字列をT型へパースする
            let elem = unsafe { str::from_utf8_unchecked(&self.buf) };
            elem.parse().unwrap_or_else(|_| panic!(format!("failed parsing: {}", elem)))
        }
    }

}

hatoo@githubさんの入力マクロを高速化してみる

Rustで競技プログラミング スターターキット -- 入力 -- マクロを使ってみる より

毎回上のようなコードを書くのはしんどいので、自分はこんな感じのマクロで入力を処理しています。

こちらは一行ごとに入力を処理しています。

マクロを使った例は上のパターンに比べてかなり遅いですが $10^5$ 程度までの入力であればそれが原因でTLEになることはまずないのではないかと思います。

ここで「上のパターン」と呼んでいるのは、入力ファイル全体を一度に String に取り込んでから、個々の値に分割する方法です。それについては高速化の余地がなさそうなので本記事の対象外とし、マクロ版のみ高速化を試みました。

記事で紹介されている get!() マクロは以下のようになります。

src/bin/main2a.rs
macro_rules! get {
    ($t:ty) => {
        {
            let mut line: String = String::new();
            std::io::stdin().read_line(&mut line).unwrap();
            line.trim().parse::<$t>().unwrap()
        }
    };
    ($($t:ty),*) => {
        {
            let mut line: String = String::new();
            std::io::stdin().read_line(&mut line).unwrap();
            let mut iter = line.split_whitespace();
            (
                $(iter.next().unwrap().parse::<$t>().unwrap(),)*
            )
        }
    };
    ($t:ty; $n:expr) => {
        (0..$n).map(|_|
            get!($t)
        ).collect::<Vec<_>>()
    };
    ($($t:ty),*; $n:expr) => {
        (0..$n).map(|_|
            get!($($t),*)
        ).collect::<Vec<_>>()
    };
    ($t:ty ;;) => {
        {
            let mut line: String = String::new();
            std::io::stdin().read_line(&mut line).unwrap();
            line.split_whitespace()
                .map(|t| t.parse::<$t>().unwrap())
                .collect::<Vec<_>>()
        }
    };
    ($t:ty ;; $n:expr) => {
        (0..$n).map(|_| get!($t ;;)).collect::<Vec<_>>()
    };
}

なお、gyu-donさんが以前紹介された「Rustで競プロに使えそうな入力取得用マクロ」もhatoo@githubさんのマクロとよく似た構造になっています。

元のマクロを使う

マクロを使って性能測定用データを読み込むには以下のようにします。

src/bin/main2a.rs
fn main() {
    let count = get!(usize);

    for i in 0..count {
        let (a, b, c) = get!(u32, f64, String);
        assert_eq!(i as u32, a);
        assert_eq!(i as f64 * -1.0e3, b);
        assert_eq!(i.to_string(), c);
    }
}

高速化してみる

元のマクロですが、すでに BufRead トレイトの read_line() で一行を String として受け取り、str::split_whitespace() で分解する形になっています。そのため、BufRead::read_until() は使う必要がありません。

一方、stdin.lock() は毎回取得していますので先ほどと同じ手法で改良しましょう。


追記:記事の投稿後、入力データを、元の一行あたり三件から、一行あたり一件に減らして試したところ、マクロ版よりも関数版の方が15%ほど速いという結果が出ました。

入力関数版は BufRead::read_until() を使用し入力データを一度だけスキャンしますが、入力マクロでは BufRead::read_line()str::split_whitespace() によって入力データを二度スキャンするため前者の方が速くなると考えられます。なお、BufRead::read_line()BufRead::read_until() で実現されています。


まず使い方(main() 関数)です。

src/bin/main2b.rs
fn main() {
    // stdinのロックを取得し、StdinReaderを作る。
    let stdin = std::io::stdin();
    let mut r = StdinReader::new(stdin.lock());

    let count = get!(r, usize);

    for i in 0..count {
        let (a, b, c) = get!(r, u32, f64, String);
        assert_eq!(i as u32, a);
        assert_eq!(i as f64 * -1.0e3, b);
        assert_eq!(i.to_string(), c);
    }
}

StdinReader 構造体とマクロです。

src/bin/main2b.rs
use std::io::BufRead;

pub struct StdinReader<R: BufRead> {
    pub reader: R,
    pub buf: String,
}

impl<R: BufRead> StdinReader<R> {
    pub fn new(reader: R) -> Self {
        Self {
            reader,
            buf: String::new(),
        }
    }
}

macro_rules! get {
    ($r:expr, $t:ty) => {
        {
            let mut line = &mut $r.buf;
            line.clear();
            $r.reader.read_line(&mut line).unwrap();
            line.trim().parse::<$t>().unwrap()
        }
    };
    ($r:expr, $($t:ty),*) => {
        {
            let mut line = &mut $r.buf;
            line.clear();
            $r.reader.read_line(&mut line).unwrap();
            let mut iter = line.split_whitespace();
            (
                $(iter.next().unwrap().parse::<$t>().unwrap(),)*
            )
        }
    };
    ($r:expr, $t:ty; $n:expr) => {
        (0..$n).map(|_|
            get!($r, $t)
        ).collect::<Vec<_>>()
    };
    ($r:expr, $($t:ty),*; $n:expr) => {
        (0..$n).map(|_|
            get!($r, $($t),*)
        ).collect::<Vec<_>>()
    };
    ($r:expr, $t:ty ;;) => {
        {
            let mut line = &mut $r.buf;
            line.clear();
            $r.reader.read_line(&mut line).unwrap();
            line.split_whitespace()
                .map(|t| t.parse::<$t>().unwrap())
                .collect::<Vec<_>>()
        }
    };
    ($r:expr, $t:ty ;; $n:expr) => {
        (0..$n).map(|_| get!($r, $t ;;)).collect::<Vec<_>>()
    };
}

この改良版ですが先ほどの main() 関数でしかテストしていないため、バグがあるかもしれません。もしバグを見つけたら、コメント欄でお知らせいただくか、編集リクエストで修正版を送っていただけると助かります。

性能評価

  • Rust 1.24.1
  • FreeBSD 11.1-RELEASE (amd64)
  • Intel Core i5 7200U @ 2.50GHz、32GB RAM

性能測定用データの作成 で生成した入力データを読み込むのにかかった時間を測定しました。

$ cargo build --release --bin main1a
$ time ./target/release/main1a < test.in

それぞれ5回ずつ実行し、最高値と最低値を除いた3つの値について、算術平均を求めました。(入力データが ZFS ファイルシステムのメモリににキャッシュされた状態で実行しています)

プログラム 内容 所要時間(秒)
main1a read() 関数 13.678
main1b 改良版 reads()readl() メソッド 5.270
main1b改 main1bで str::from_utf8_unchecked() を使用 4.760
main2a get!() マクロ 6.042
main2b 改良版 get!() マクロ 4.836

追記

  • さらなる高速化 の変更を施した結果を「main1b改」として追加しました
  • 上の条件(一行あたり三件)では関数版(main1b)の方がマクロ版(main2b)よりも9%ほど遅くなりましたが、入力データを一行あたり一件に減らして試したところ、main1b のほうが main2b よりも15%ほど速くなりました。理由については次の「まとめ」を読んでください

まとめ

  • tubo28さんの入力関数 → 約2.60倍高速化できた
  • hatoo@githubさんの入力マクロ → 約1.25倍高速化できた
  • 高速化後の入力関数とマクロの速さは同程度(約9%の差)なので、どちらを使うかはお好みしだい
  • 追記: 高速化後の入力関数とマクロでは入力関数の方がわずかに速い
    • 入力関数版は BufRead::read_until() を使用し入力データを一度だけスキャンするが、入力マクロでは BufRead::read_line()str::split_whitespace() によって入力データを二度スキャンする。(なお、BufRead::read_line()BufRead::read_until() で実現されている)
    • 入力関数版では str::from_utf8() の代わりに unsafe { str::from_utf8_unchecked() } を使うことで、さらに高速化できる
    • マクロ版を書き換えて、BufRead::read_line()str::split_whitespace() の代わりに入力関数版を呼んでもいいかもしれない。