rust
競技プログラミング

AtCoder に登録したら解くべき精選過去問 10 問を Rust で解いてみた

はじめに

drken さんの素晴らしい記事で紹介されていた AtCoder に登録したら解くべき精選 10 問 を、現在お気に入りの言語である Rust で解いてみました。
扱う問題一覧はこちら https://beta.atcoder.jp/contests/abs/tasks

自分が Rust で競技を始めたときはかなり躓いてドキュメントや stackoverflow を見まくったので、そんな負担を少しでも軽減できればと思い書きました。

対象

この記事は入門書をさらっと読んだ方、または他の言語経験があって Rust の雰囲気だけでも知りたい人が対象です。

入出力

どの言語でも最初に躓くのが入力だと思います。Rust も例にもれずそうなりがちで、Rust 競プロerがそれぞれ独自に入出力用ツールを自作しているのが現状です。
この記事では、一例として次の入力関数を使います。以降の記事中のソースコードは定義を省略します。
read 関数は標準入力から空白をスキップしてトークンを受け取り、T 型に変換して返します。エラーハンドリングを行わない競技プログラミング用の簡素なもので、読み込みと変換に失敗した場合は異常終了します。
また、問題になる場面はあまりないのですがこの入力関数は少々遅いので、速い入出力が必要になったときはおのおの解決方法を調べてほしいです。

use std::io::*;
use std::str::FromStr;

fn read<T: FromStr>() -> T {
    let stdin = stdin();
    let stdin = stdin.lock();
    let token: String = stdin
        .bytes()
        .map(|c| c.expect("failed to read char") as char) 
        .skip_while(|c| c.is_whitespace())
        .take_while(|c| !c.is_whitespace())
        .collect();
    token.parse().ok().expect("failed to parse token")
}

出力は、println マクロを使えばあまり難しくありません。簡単な使い方は以降のコードを見たほうが早い気がします。
println も、高速ではありません。

2018/04/02 追記 :
tatsuya6502 さんが入出力高速化についてフォローしてくださいました。ありがとうございます!
Rustの競プロ向け入力関数やマクロを高速化できるかやってみた

問題の解答

問題文はごく簡単な文に要約して載せたので、厳密に知りたいときは上に貼った問題文へのリンクから全文を見てください。

第 1 問: ABC 086 A - Product (100 点)

整数 $a$, $b$ の積が偶数か判定する問題です。
3 項演算子が無い代わりに、if 文を評価することができます。したがって文より式と呼んだほうがいい場面もあります。
let a: u32 = read() は C++ 等と同様に let a = read::<u32>() とも書くことができます。前者では、a: u32 と型を明示していることから、最初に示した read の定義における T が推論されています。後の方で登場しますが、定義後の a の使われ方によって型が定まる場合は、単に let a = read() と書くこともできます。

fn main() {
    let a: u32 = read();
    let b: u32 = read();
    let ans = if (a * b) % 2 == 0 { "Even" } else { "Odd" };
    println!("{}", ans);
}

第 2 問: ABC 081 A - Placing Marbles (100 点)

文字列 $s$ に 1 がいくつ含まれるか数える問題です。
Rust の文字列の概念は少し複雑なので、文字列を扱う場面では、chars メソッドで Iterator<Item = char> を取得し、必要であればさらに Vec<char> (C++ における std::vector<char>) に変換してしまうのがおすすめです。 (Rust が複雑というより C++ がいいかげんなだけなのですが。)
String 型には into_bytes という Vec<u8> に変換するメソッドもあります。
次のコードでは chars メソッドで文字列上のイテレータを取得し、filter1 だけを残し、残った個数を出力しています。

fn main() {
    let s: String = read();
    println!("{}", s.chars().filter(|&c| c == '1').count());
}

Vec<char> を使うと次のようになります。Rust における文字列は String&str の二種類あるのですが、どちらもインデックスアクセスできない設計になっているので、したいときは Vec<char> に変換するのがおすすめです。
collect は、Vec を始めとする std::collections::VecDequestd::collections::BTreeSet などのデータ構造を構築する関数です。s の型が Vec であると明示すると、collect に型引数を渡さなくても Vec 型と推論してくれます。

fn main() {
    let s: Vec<char> = read::<String>().chars().collect();
    let mut cnt = 0;
    if s[0] == '1' { cnt += 1; }
    if s[1] == '1' { cnt += 1; }
    if s[2] == '1' { cnt += 1; }
    println!("{}", cnt);
}

第 3 問: ABC 081 B - Shift Only (200 点)

整数 $a_1, \ldots, a_n$ の中で $2$ で割り切れる回数が最小のものの、割り切れる回数を求める問題です。
この問題は、よく考えると整数を 2 進数表記したときに、最も下の桁から連続する 0 の個数の最小値を求めればよいことが分かります。
上の問題では read で読み込む値の型を明示してきましたが、今回は明示していません。明示しなかった場合は、以降のその値の使われ方で推論できる場合は推論され、できなければエラーになります。下のコードでは (0..n) の部分から整数型であることが分かるので、明示する必要がありません。

Range オブジェクト 0..nmap で割り切れる回数に写し、最小値を求める方針にします。map に渡すクロージャで u32 型の整数を読み込み、さらに、整数型には trailing_zeros という便利なメソッドがあるので、それを使って最も下の桁から連続する 0 の個数を求めます。最後に min を使ってその最小値を求めます。min はイテレータの長さが 0 だった場合も考慮して、Option<T> 型を返す設計になっていますが、今回は $n > 0$ が保証されているので unwrap で剥がします。

fn main() {
    let n = read();
    println!(
        "{}",
        (0..n)
            .map(|_| read::<u32>().trailing_zeros())
            .min()
            .unwrap()
    );
}

第 4 問: ABC 087 B - Coins (200 点)

整数 $a,b,c,n$ が与えられて、$500i+100j+50k=x$ を満たす整数 $i,j,k \ (0 \le i \le a, 0 \le j \le b, 0 \le k \le c)$ の個数を求める問題です。
多重ループが必要です。残念ながら Ruby の Array#product のようなメソッドは標準ライブラリには無いので C++ と同様に for ループを回します (外部ライブラリにはありますが、それをダウンロードしてこれるオンラインジャッジは今のところありません。) もっと良い書き方あるのかなあ。

fn main() {
    let a: u32 = read();
    let b: u32 = read();
    let c: u32 = read();
    let x: u32 = read();
    let mut ans = 0;
    for i in 0..a + 1 {
        for j in 0..b + 1 {
            for k in 0..c + 1 {
                if i * 500 + j * 100 + k * 50 == x {
                    ans += 1;
                }
            }
        }
    }
    println!("{}", ans);
}

第 5 問: ABC 083 B - Some Sums (200 点)

$s(x) := 「x$ を $10$ 進数表記したときの桁の和」と定義します。整数 $n,a,b$ が与えられて、$1 \le x \le n, a \le s(x) \le b$ を満たす $x$ の和を求める問題です。
filter は先程も使いましたね。filter で整数から桁の和に変換して条件で絞った後、sum メソッドで合計を求めます。イテレータが内包するオブジェクトの型が分かっていても、sum に型引数を与えなければいけません。
整数の桁の和を求める方法はいろいろとありますが、下のコードでは文字列に変換し、さらに各桁の文字 (char) を 0 から 9 の整数に変換して和を求めます。
a <= sum && sum <= b (セミコロンが無いことに注意) の部分で使っているように、関数とクロージャーは最後の文を式にした場合にそれが戻り値となります。

fn main() {
    let n: u32 = read();
    let a: u32 = read();
    let b: u32 = read();
    let ans = (1..n + 1)
        .filter(|x| {
            let sum = x.to_string()
                .chars()
                .map(|c| (c as u8 - b'0') as u32)
                .sum::<u32>();
            a <= sum && sum <= b
        })
        .sum::<u32>();
    println!("{}", ans);
}

第 6 問: ABC 088 B - Card Game for Two (200 点)

$a_1, \ldots, a_n$ を降順ソートし、偶数番目の和と奇数番目の和の差を求める問題です。
Vec を昇順ソートするには sort メソッドを使えばよいのですが、今回は降順ソートしたいので、引数に関数またはクロージャをとる sort_by を使います。x > y でソートしたいときには |x, y| y.cmp(x) を与えればよいです。cmp メソッドは、比較可能なオブジェクトを渡すと、比較して、y < xy == xy > x かを判定し、対応する列挙型 Ord のメンバを返します (ドキュメント参照)。sort_by はそれをもとにソートします。
enumerate は Ruby の Enumerate#each_with_index や、Python の zip(hoge, range(n)) に対応するもので、インデックスと要素への参照のタプルのイテレータを作成します。

fn main() {
    let n = read();
    let mut a: Vec<u32> = (0..n).map(|_| read()).collect();
    a.sort_by(|x, y| y.cmp(x));
    let mut alice = 0;
    let mut bob = 0;
    for (i, &x) in a.iter().enumerate() {
        if i % 2 == 0 {
            alice += x;
        } else {
            bob += x;
        }
    }
    println!("{}", alice - bob);
}

-x で昇順ソートしてもいいです。unsigned 型にはマイナス記号が付けられなくなっているので、singed 型を使います。

let mut a: Vec<i32> = (0..n).map(|_| read()).collect();
a.sort_by_key(|&x| -x);

昇順ソート後に reverse してもいいです。

let mut a: Vec<u32> = (0..n).map(|_| read()).collect();
a.sort();
a.reverse();

ここで使った sort は最悪時間計算量 $O(n \log n)$ の安定ソートアルゴリズムであることが保証されています。
より高速にソートを行いたく、安定でなくてもよい場合は sort_unstable を使います。(ただし、残念ながら現在 (2018/03/14) の AtCoder のバージョンでは使えないようです。。。)

第 7 問: ABC 085 B - Kagami Mochi (200 点)

$a_1, \ldots, a_n$ から重複する要素を取り除いたときに何個残るか求める問題です。
この問題もいろいろな方法で解けますが、一番シンプルであろう方法が、collections::BTreeSet (C++ の std::set) という集合を管理するデータ構造に突っ込んで重複を取り除くものでしょうか。

fn main() {
    use std::collections::BTreeSet;
    let n = read();
    let a: BTreeSet<u32> = (0..n).map(|_| read()).collect();
    println!("{}", a.len());
}

dedup メソッドを使うのもきれいですし、メモリ的にはこちらの方が効率的です。

let mut a: Vec<u32> = (0..n).map(|_| read()).collect();
a.sort();
a.dedup();
println!("{}", a.len());

BTreeSet の他にも HashSet もあります。前者は内包する型が Ord トレイトの実装を持つこと (比較可能であること)、後者は Hash トレイトの実装を持つこと (ハッシュ値の計算ができること) が使える要件にあります。

第 8 問: ABC 085 C - Otoshidama (300 点)

整数 $n,x$ が与えられ、$i+j+k=n, 10000i+5000j+1000k=x$ となる整数 $i,j,k$ を求める問題です。
多重ループを使います。Rust には当然 GOTO 文は存在しません。代わりに、ループのラベル付き break が使えます。下のコードでは一番外側のループに outer というラベルを付けて、答えが見つかったら抜けています。
C++ などでは解が見つかったことを表すフラグを用意するのがよいのかもしれませんが、Rust では Option 型があり、せっかくなので使ってみています。

fn main() {
    let n: i32 = read();
    let x: i32 = read();
    let mut ans = None;
    'outer: for i in 0..n + 1 {
        for j in 0..n - i + 1 {
            let k = n - i - j;
            if i * 10000 + j * 5000 + k * 1000 == x {
                ans = Some((i, j, k));
                break 'outer;
            }
        }
    }
    let (x, y, z) = ans.unwrap_or((-1, -1, -1));
    println!("{} {} {}", x, y, z);
}

第 9 問: ABC 049 C - Daydream (300 点)

文字列 $s$ が与えられます。dream, dreamer, erase, eraser を好きな順番で好きな個数だけ連結することで $s$ を作れるか判定する問題です。
解法は drken さんの記事で紹介されている貪欲法を使います。
&str には、引数に与えたパターン文字列が出現する箇所を見つけるメソッドが用意されています。文字列は基本的に Vec<char> で扱うのが楽なのですが、この問題では find メソッドを使いたいので String 型を使っています。
変数 s で「まだマッチしていない最初の文字を先頭とする文字列 (suffix) への参照」を管理し、それに対してパターン文字列が一致する最初の箇所が 0 文字目か否かで、prefix の判定をしています。

fn main() {
    let patterns: Vec<String> = ["dream", "dreamer", "erase", "eraser"]
        .iter()
        .map(|s| s.chars().rev().collect())
        .collect();
    let s: String = read::<String>().chars().rev().collect();
    let mut s = &s[..];
    let mut succeeded = true;
    while s.len() > 0 {
        let matched = patterns.iter().find(|&p| s.find(p) == Some(0));
        if let Some(p) = matched {
            s = &s[p.len()..];
        } else {
            succeeded = false;
            break;
        }
    }
    println!("{}", if succeeded { "YES" } else { "NO" });
}

2018/03/14 22:17 追記 : find メソッドを使うと計算量的によくない (未検証ですが最適化のおかげで定数倍の差で済むのか、なぜか通ってしまった) ことと、普通に Vec<char> でも書けることに気づきました。あまり変わりませんが次のようにも書けます。
2018/03/15 1:40 修正 : slicestarts_with というメソッドを知って、それを使うコードに修正しました。

fn main() {
    let patterns: Vec<Vec<char>> = ["dream", "dreamer", "erase", "eraser"]
        .iter()
        .map(|s| s.chars().rev().collect())
        .collect();
    let s: Vec<char> = read::<String>().chars().rev().collect();
    let mut s = &s[..];
    let mut succeeded = true;
    while s.len() > 0 {
        let matched = patterns.iter().find(|&p| s.starts_with(p));
        if let Some(p) = matched {
            s = &s[p.len()..];
        } else {
            succeeded = false;
            break;
        }
    }
    println!("{}", if succeeded { "YES" } else { "NO" });
}

2018/03/15 12:28 追記 : Delta114514 さんの記事を見てもっとシンプルな方法を知りました。
よくよく考えると前から eraser -> erase -> dreamer -> dream の順に、String のメソッド replace を用いて空文字列に置換すればよいですね。こちらは思いつきませんでした。
さて、Rust のイテレータには次のループを initfunc に関して一般化するメソッド fold があります。

let mut acc = init;
for x in iter {
    acc = func(acc, x);
}

すごい H 本を読んで感動する部分の一つだと思うのですが、fold を使うと多くのループ処理が非常にシンプルにかけます。この問題の解答も次のようにきれいに書けます。

fn main() {
    let patterns = ["eraser", "erase", "dreamer", "dream"];
    let s = patterns.iter().fold(read::<String>(), |s, x| s.replace(x, ""));
    println!("{}", if s.is_empty() { "YES" } else { "NO "});
}

第 10 問: ABC 086 C - Traveling (300 点)

$n$ 個の時間と 2 次元平面上の点の座標の組 $(t_1, (x_1, y_1)), \ldots, (t_n, (x_n, y_n))$ が与えられます。便宜上、単位は秒とメートルとします。旅人は $1$ 秒に $1$ m だけ軸に平行な 4 方向のうちいずれかの向きに直進しなくてはいけません。$t=0$ のとき旅人は $(0, 0)$ にいます。$1$ 秒単位で自由に移動し、各 $i$ に対し時刻 $t_i$ に $(x_i, y_i)$ の位置にいることはできるか判定する問題です。
drken さんの記事で紹介されているように、ある点とその次に訪れる点の距離と時間の関係のみに注目すればいいです。Iteratorwindows メソッドを使うと、隣り合う k 個の要素からなるブロックを前から順に走査する新しいイテレータを取得できます。all メソッドを使って、全ての隣り合う 2 つの要素が、条件を満たすか判定できます。

fn main() {
    let n = read();
    let mut v: Vec<(i32, i32, i32)> = (0..n).map(|_| (read(), read(), read())).collect();
    v.insert(0, (0, 0, 0));
    let yes = v[..].windows(2).all(|w| {
        let (t, x, y) = w[0];
        let (nt, nx, ny) = w[1];
        let time = nt - t;
        let dist = (nx - x).abs() + (ny - y).abs();
        dist <= time && time % 2 == dist % 2
    });
    println!("{}", if yes { "Yes" } else { "No" });
}

おわりに

Rust はいいぞ