LoginSignup
0
0

More than 1 year has passed since last update.

Pythonのジェネレータを使って競プロを遊ぶ

Posted at

この記事は何

Pythonで競プロを遊ぶときにジェネレータをたまに使っているのですが、使用感が良いので紹介してみようという記事です。個人的に感じたメリットとしては、

  • メイン部分の実装が簡潔になる
  • Python特有の機能を使えている気分になれてうれしい(?)

などがあります。
Pythonでの競プロ実装する際の一方針を紹介するものであり、計算速度が改善したりするわけではないです。

動作はPython3.8.6で確認しています。
不備や不正確な点などがありましたら、ご指摘いただけるとうれしいです。

ジェネレータとは?

pythonのジェネレータに関してざっくり説明します。
詳しい説明はDocumentationや他の方の記事を参照いただければと思います。

関数のreturnyieldに置き換えるとジェネレータになります。
関数は値をreturnした時点で処理を終了します。一方、ジェネレータでは値をyieldした時点の状態が保持され、再びジェネレータが呼び出されるとその続きから処理が再開されます。
関数がを返すのに対して、ジェネレータは一連のデータをひとつずつ返すオブジェクトであるイテレータを返してくれます。

全競プロpythonistaが使っているmap()もジェネレータで同じ機能が実装できます。

def mymap(func, value):
    for v in value:
        yield func(v)


inp = "1 1 2 3 5 8"
A = list(mymap(int, inp.split()))
B = list(map(int, inp.split()))
print(A)    # -> [1, 1, 2, 3, 5, 8]
print(B)    # -> [1, 1, 2, 3, 5, 8]

返したい一連の値をジェネレータで具体的に設定できる、という感じです。

ジェネレータで定義した一連のデータは、list()でリスト化したり、max()で最大値を求めたり、後述するようにfor文で値を逐次取り出すことができます。値をひとつずつ取得したいときはnext()関数を使用します。

競プロにおける使用例

簡単な例4つほど

1. 条件分岐(if文)をいい感じにまとめる

ABC135D Digits Paradeを例に紹介します。(軽いネタバレがあります)

 
 
 
 

桁の一部が?で隠された整数$S$が与えられるので、隠された部分を数字に置き換えたときに条件を満たすものを数え上げる問題です。 
この問題では$S_i$が?のときは0 ~ 9の数字に置き換えて、$S_i$が数字のときはそのまま用いてdpテーブルを更新していくことで解くことができます。素直に実装するとこんな感じになると思います。

for i in range(len(S)):
    for j in range(13):
        if S[i] == "?":
            for k in range(10):
                dp[i+1][ ] += dp[i][k]    # dp遷移
        else:
            dp[i+1][ ] += dp[i][int(S[i])]    # dp遷移

$S_i$の値に応じて処理が少し変わるため、同じようなdp遷移を2回書くことになっていますね。
ジェネレータを用いてif文の部分をまとめると、下のようになります。

def gen(s):
    if s == "?":
        yield from range(10)
    else:
        yield int(s)


for i in range(len(S)):
    for j in range(13):
        for k in gen(S[i]):
            dp[i+1][ ] += dp[i][k]    # dp遷移

ジェネレータgenは、引数sが?のときは整数0 ~ 9をひとつずつ返してくれ、sが数字の場合はsを整数に変換したものを返してくれます。メインのforループの部分を少しすっきりさせることができましたね(実装例)。
yield from [iterableなオブジェクト]と書くことで、オブジェクトから要素をひとつずつyieldすることができます。

2. 部分集合の列挙

bitDPでたまに見る、$N$要素からなる集合の部分集合$S$の部分集合$T$を$O(3^N)$で列挙するやつです(問題例:EDPC典型90)。

cppだと以下のように書けます。シンプルでいいですね

cpp
for (int S=0; S<(1 << N); ++S) {
	for (int T=S; T>0; T=(T-1)&S) {
 		dp[S] = ...      // dp遷移
    }
}

pythonのrange()では、上コード2行目のような条件式や増減式をワンライナーで実装することはできません。while文を用いて以下のように書くことになると思います。

for S in range(1 << N):
    T = S
    while T:
        dp[S] = ...      # dp遷移
        T = (T - 1) & S

これでは頭が爆発してしまいます(断言)。
ジェネレータで条件式などを分けて実装すると、下のようになります。

def subset_of_(S):
    T = S
    while S:
        yield S
        S = (S - 1) & T


for S in range(1 << N):
    for T in subset_of_(S):
        dp[S] = ...      # dp遷移

メインのforループがシンプルになっておりdp遷移に集中して実装できそうな気がします。
ジェネレータsubset_of_はfor文を用いることで以下のように$S$の部分集合$T$を降順に列挙します。

def subset_of_(S):
    T = S
    while S:
        yield S
        S = (S - 1) & T


S = int("1101", 2)
for T in subset_of_(S):
    print(bin(T)[2:].zfill(4), T)
 
"""
出力
1101 13
1100 12
1001 9
1000 8
0101 5
0100 4
0001 1
"""

ビット演算系はジェネレータを用いてライブラリ化しやすいと思います。

3. 周期的に値を生成する

周期3の場合を考えてみます(下の実装例を見たほうが早いかもです)。
$3k$番目にa、$3k+1$番目にb、$3k+2$番目にcを取り出したい状況を考えます($k$は非負整数です)。この問題で使いたくなりました(提出)。

def gen():
    a = 1
    b = 4
    c = 9
    while True:
        yield a
        yield b
        yield c


g = gen()
print(next(g))      # -> 1
print(next(g))      # -> 4
print(next(g))      # -> 9
print(next(g))      # -> 1

for i, x in enumerate(g):
    print(x, end=", ")   # ->  4, 9, 1, 4, 9, 1, ...
    if i >= 10:
        break

g = gen()で生成したイテレータgに対して、next()を適用することでabcabc、...と繰り返し生成してくれます。yieldに到達した段階の状態が保持されるため、このような機能が実現できるわけですね。(上の例だとfor文を使うと無限ループになるので注意です。)

4. Xorshift

疑似乱数列を生成するアルゴリズムです。

def xor64(seed=88172645463325252):
    x = seed
    mask = (1 << 64) - 1
    while 1:
        x = x ^ (x << 13) & mask
        x = x ^ (x >> 7)
        x = x ^ (x << 17) & mask
        yield x


rnd = xor64()
for _ in range(5):
    print(next(rnd))
"""
出力
8748534153485358512
3040900993826735515
3453997556048239312
16431732851926010853
8204724074003728306
"""

関数を用いてxorshiftを実装しようとすると上のコードでいうxをグローバル変数にしたりする必要がありそうですが、ジェネレータの実装では不要になります。

pythonの整数はオーバーフローしないため、毎回bitmaskしているのがあまりスマートではないですね。。パフォーマンスもそこまで良くないので乱数を扱いたいならpythonのrandomモジュールを使えばいいとは思います。

余談ですが、どうしてもxorshiftを自前実装したい場合、for文で使わないのであればジェネレータよりもクロージャを使うほうがちょっとスマートかもしれません。乱数を生成するときにnext()を毎回書く必要がないので。

クロージャを用いた実装例
def xor64(seed=88172645463325252):
    x = seed
    mask = (1 << 64) - 1

    def inner():
        nonlocal x
        x = x ^ (x << 13) & mask
        x = x ^ (x >> 7)
        x = x ^ (x << 17) & mask
        return x
    return inner


rnd = xor64()
for _ in range(5):
    print(rnd())
"""
出力
8748534153485358512
3040900993826735515
3453997556048239312
16431732851926010853
8204724074003728306
"""

おわり

無限ループ、気をつけようね

参考

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0