特に明記無い限り、本記事は2023年AtCoder言語アップデートの前の情報です
AtCoderをはじめたので、自分用スニペット集を、少しずつ記述していきます。
2021/4/18 入緑しました!
2021/5/1 記事を分割しました
2021/7/20 記事を分割しました
2021/9/28 記事が概ね完成しましたので、「更新型記事」という表示を削除しました
2021/12/19 入水しました!
2023/7/8 AtCoder言語アップデート2023を先取りして第3章を追加検証しました
- (1) 基本編(本稿)
- (2) 応用編
- 参考: 【競プロ】PyPy3で使える!Numpy代用ライブラリ
- 参考: 【競プロ専用】PythonでMultiSetを今度こそ
応用編では、以下の典型アルゴリズムを、実際の問題を解きつつ、速度測定しています。
- 累積和、いもす法、BIT、セグメント木、bit全探索、二分探索、DFS、BFS、ワーシャルフロイド法、ダイクストラ法、Union Find、強連結成分分解、最大流問題、動的計画法
0. 本記事の概要
AtCoderでは、与えられた問題の多くは、プログラム実行時間2秒の制限時間の中で解答を出力する必要があります。そのため上位者は高速なC++を使います。私はPythonを使うため、時間制約が厳しくなってしまいます。そこで、実行時間に配慮したスニペットを用意しておくことにしました。
AtCoderで実行時間が制限超過することを、「TLE」と呼びます。
各コードは、一般的な計算量オーダーを示すのみならず、実際にデータ量を変えて、AtCoderのコードテストでベンチマークを行うことを目指したいと思います。
またPython3.8だけでなく、高速なPython互換であるPyPy3でもベンチマークをします。AtCoderの問題は、生Python3.8ではTLEになりPyPy3が必須になる場合がしばしば存在するためです。そのため、スニペットで利用する文法・ライブラリは、PyPy3で通せるものに限定します。一方、PyPy3特有のクセ(処理によってはPython3.8より遅くなること)も存在します。よって、コードテストによるベンチマークによりPyPy3のクセを把握しておくことがとても重要です。
AtCoderのPython3.8は、numpyやscipyといった高速かつ高度な数値計算ライブラリを使えますので、一部の問題はPython3.8+ライブラリで取り組むことが最適な場合もあります。ただし、それらは例外的ですので、本記事ではスコープ外とします。
0.1. 基本オーバーヘッド
N = 1
という1行を、コードテストすることで、Python3.8とPyPy3の基本的な言語オーバーヘッドを測定します。以降のベンチマークは生の値を示していますが、下記の値程度の言語オーバーヘッドが含まれるものと考えてください。
Python3.8: 23ms
PyPy3: 70ms
0.2. コードテストの終了コード
コードテストはTLE等の異常を終了コードで示します。いくつかの終了コードは意味が公開されているか、コードテストにエラーメッセージが表示されて意味が推測可能です。
終了コード | 意味 |
---|---|
-1 | コンパイルエラー(Pythonでは発生しない) |
9 | TLE |
134 | インデックス範囲外 |
136 | ゼロ除算例外 |
139 | スタックオーバーフロー(の可能性あり) |
256 | その他のエラー(エラーメッセージ表示あり) |
1. 入出力
AtCoderは、問題の読み込みと解答の出力に標準入出力を使います。問題によっては数10万のパラメータを入出力する必要があります。ただし、数10万というのはあまり大きい数では無いため、それほど高速化を意識する必要はありません。
1.1. 入力処理(通常)
input()
を使って標準入力を受け取ります。ほとんどの問題では、このやり方で速度的に十分です。なお、ベンチマークは、項1.6のテストデータ生成を使って、すべてのデータを乱数ではなく0として生成し、コードテストの入力ウインドウに貼り付けて実施しています。
1.1.1. 1つの整数
N = int(input())
実行時間はオーバーヘッドを除きほぼゼロであるため、ベンチマークは省略します。
1.1.2. 複数の横並び整数を異なる変数に代入
N, M = map(int, input().split())
実行時間はオーバーヘッドを除きほぼゼロであるため、ベンチマークは省略します。
1.1.3. 複数の横並び整数をリストに代入
N = int(input()) # Nは使わない場合が多い
S = list(map(int, input().split()))
N | 10**4 | 10**5 |
---|---|---|
Python3.8 | 20ms | 48ms |
PyPy3 | 73ms | 85ms |
1.1.4. 複数の縦並び整数をリストに代入
N = int(input())
S = [int(input()) for _ in range(N)]
N | 10**4 | 10**5 |
---|---|---|
Python3.8 | 39ms | 141ms |
PyPy3 | 101ms | 172ms |
1.1.5. 複数の縦横並び整数をリストに代入
H, W = map(int, input().split())
S = [list(map(int, input().split())) for _ in range(H)]
よくある縦横並びの入力は縦が10万オーダーですので、W = 2
固定として、Hを大きくしてみます。
H | 10**4 | 10**5 |
---|---|---|
Python3.8 | 46ms | 199ms |
PyPy3 | 109ms | 212ms |
1.2. 入力処理(高速版)
sys.stdin
を使うことで入力が速くなると言われていますが、ベンチマーク結果によると、使う意味があるのは縦並びで10万オーダーの入力をする時だけだとわかります。それでも高々100msの差であるため、以下の使い勝手を考慮すると、普段は通常版を使うようにするのが懸命です。使うのは、「あと数10msの差でTLEになってしまう」という場面か、AtCoder Heuristic Contestのような「地味な時間短縮がスコアに影響する」という場面などに、限定されます。
しかしながら、実装は簡便で、以下のように冒頭に記載して、標準のinput()
を置き換えるだけです。「おまじない」として必ず記述しておくことも好みでよいと思います。
コードテストの容量限界で試せませんが、50万行くらいの入力がある問題では、450msくらい高速化します。(10万行で高々100ms、というベンチマーク結果からみても妥当)
この記述だと、ローカルテスト時でも、標準入力後に
^D
を入れて入力終了を明示的に示す必要はありません。
import sys
def input(): return sys.stdin.readline()[:-1]
NまたはH | 10**4 | 10**5 | |
---|---|---|---|
複数の横並び整数をリストに代入 | Python3.8 | 20ms | 37ms |
PyPy3 | 65ms | 89ms | |
複数の縦並び整数をリストに代入 | Python3.8 | 26ms | 58ms |
PyPy3 | 71ms | 83ms | |
複数の縦横並び整数をリストに代入 | Python3.8 | 38ms | 127ms |
PyPy3 | 80ms | 123ms |
1.3. 典型前処理
入力とあわせて、典型的な前処理を示します。ベンチマークは省略します。
1.3.1. 0-indexed化
AtCoderの問題の入力は、1からカウントする値(1-indexed)になっている場合が多いです。配列処理等には0からカウントする値(0-indexed)の方が使いやすいです。
N = int(input())
A = list(map(lambda x: int(x) - 1, input().split()))
1.3.1. ソート
N = int(input())
A = sorted(map(int, input().split()))
1.3.2. 辞書化
{要素: 出現回数}
の辞書にしたいときに便利です。Counterは辞書のサブクラスなので、辞書のメソッドはそのまま利用可能です。
N = int(input())
from collections import Counter
A = Counter(map(int, input().split()))
{要素: 出現位置のリスト}
の辞書にしたい時は、少し複雑ですが、以下のようにします。
N = int(input())
from collections import defaultdict
A = defaultdict(list)
[A[a].append(n) for n, a in enumerate(map(int, input().split()))]
1.3.3. ランレングス圧縮
(連続する要素, 連続回数)
のタプルのリストにしたいときに便利です。辞書化と異なり、同じ要素が複数回現れる可能性があり、順番も保持する必要があるため、タプルのリストを使います。
N = int(input())
from itertools import groupby
A = [(k, len(list(g))) for k, g in groupby(map(int, input().split()))]
1.3.4. 有向グラフ
{頂点: [隣接頂点, ・・・]}
という形でグラフ構造を表現します。なお、たいていの問題は頂点番号が1〜Nになっていますが、0〜N-1の方が使い勝手がよいため、番号を1減らします。
N, M = map(int, input().split())
S = [list(map(int, input().split())) for _ in range(M)]
from collections import defaultdict
adj = defaultdict(list)
for a, b in S:
adj[a-1].append(b-1)
adj: adjacentは、隣接の略称です。
1.3.5. 無向グラフ
無向グラフの場合は、有向グラフと同様な方法で、行き帰り双方を登録します。
N, M = map(int, input().split())
S = [list(map(int, input().split())) for _ in range(M)]
from collections import defaultdict
adj = defaultdict(list)
for a, b in S:
adj[a-1].append(b-1)
adj[b-1].append(a-1)
1.3.6. 距離付き有向グラフ
距離cが定義されてる場合は、{頂点: [(隣接頂点, 距離), ・・・]}
という形で表現します。
N, M = map(int, input().split())
S = [list(map(int, input().split())) for _ in range(M)]
from collections import defaultdict
adj = defaultdict(list)
for a, b, c in S:
adj[a-1].append((b-1, c))
1.3.7. リストの木構造化(トライ木)
文字列や数値列などのリスト構造が多数あった場合に、構造を木で表現することで、さまざまな木のアルゴリズムが利用可能になり、問題が解きやすくなる場合があります。この構造をトライ木と呼びます。
リストの長さは短め(〜100)、リストの数は多め(10**5オーダー)の時が多いです。
N = int(input())
S = [input() for _ in range(N)]
from collections import defaultdict
adj_with_lable = [{}] # 順方向 <ラベル -> next_>
inv_adj = [None] # 逆方向 prev_ (inv_adj[0] = None)
last_node = [None] * N # 各入力の最終ノードid
for n, s in enumerate(S):
pos = 0
for ch in s:
if ch in adj_with_lable[pos]:
pos = adj_with_lable[pos][ch]
continue
adj_with_lable[pos][ch] = len(adj_with_lable)
inv_adj.append(pos)
pos = len(adj_with_lable)
adj_with_lable.append({})
last_node[n] = pos
adj = [set(kv.values()) for kv in adj_with_lable] # トライ木の順方向隣接リスト
M = len(inv_adj) # トライ木のノード数
1.4. 入力処理(出力つき)
縦並び入力の1つ1つがクエリーとなっている問題では、入力から出力までを統合することも可能です。しかし、前処理の場合と異なり、出力まで統合すると遅くなってしまうようです。
1.4.1. 入力と出力を分離
N = int(input())
S = [int(input()) for _ in range(N)]
for s in S:
print(s)
N | 10**4 | 10**5 |
---|---|---|
Python3.8 | 45ms | 171ms |
PyPy3 | 124ms | 204ms |
1.4.2. 入力と出力を統合
N = int(input())
for _ in range(N):
print(int(input()))
N | 10**4 | 10**5 |
---|---|---|
Python3.8 | 67ms | 371ms |
PyPy3 | 146ms | 383ms |
1.5. 出力処理
1.5.1. リストを複数の横並び整数として出力
横並び出力はjoin
は不要で、*
をつけてリストを展開するだけです。
print(*S)
ベンチマークは横並び入力とほぼ同等です。
N | 10**4 | 10**5 |
---|---|---|
Python3.8 | 21ms | 54ms |
PyPy3 | 78ms | 87ms |
1.5.2. リストを複数の縦並び整数として出力
縦並び整数入力も横並びと同様に記述可能です。
print(*S, sep='\n')
N | 10**4 | 10**5 | 10**6 |
---|---|---|---|
Python3.8 | 20ms | 50ms | 254ms |
PyPy3 | 73ms | 100ms | 177ms |
1.6. テストデータ生成
さまざまなベンチマークをしたり、AtCoderコンテスト実施時点で素早く速度を確認したりするために、大量のテストデータファイルを作成するコードを準備しておきます。できたファイルを全選択コピーして、コードテストの標準入力にペーストして使います。
テストデータ生成が目的なので、ベンチマークは行いません。
1.6.1. 複数の横並び整数のテストデータ作成
import random
N = 10**5
S = [random.randrange(1000) for _ in range(N)]
with open('sample.txt', 'w') as f:
f.write(f'{N}\n')
f.write(' '.join(map(str, S)) + '\n')
1.6.2. 複数の縦並び整数のテストデータ作成
import random
N = 10**5
S = [random.randrange(1000) for _ in range(N)]
with open('sample.txt', 'w') as f:
f.write(f'{N}\n')
f.write('\n'.join(map(str, S)) + '\n')
1.6.3. 複数の縦横並び整数のテストデータ作成
import random
H = W = 10**2
S = [[random.randrange(1000) for _ in range(W)] for _ in range(H)]
with open('sample.txt', 'w') as f:
f.write(f'{H} {W}\n')
for s in S:
f.write(' '.join(map(str, s)) + '\n')
2 . 典型基本処理
文字列やリストなどの処理で、よく使うものや速度的な注意が必要なものを、まとめておきます。
2.1. ビット計算
ビット計算するための基本処理です。
2.1.1. 最大ビット桁数
数値を2進数と見た、最大ビット桁数を求めるのは、組み込み関数を使います。
i = n.bit_length()
速度測定は省略します。
2.1.2. ビットカウント
ビット1をカウントする組み込み関数はc++ではpopcount
という組み込み関数が用意されていますが、Pythonには用意されていません。以下のようにして求めます。
i = bin(n).count('1')
速度測定は省略します。
Python10からは、
bit_count()
関数がサポートされました。これは、C++のpopcount
と同様に定数倍が高速な実装になっているようです。
2.1.3. もっとも右のビット1
もっとも右のビット1は、以下のようにして求めます。
n = n & -n
速度測定は省略します。
2.1.4. ビット回転(rol, ror)
ビット回転はいろいろな書き方がありますので、一例です。
def rol(n, bit, num_bit):
return ((n << (bit % num_bit)) % (1 << num_bit)) | (n >> (-bit % num_bit))
def ror(n, bit, num_bit):
return rol(n, -bit, num_bit)
速度測定は省略します。
2.1.5. 部分集合の列挙
各ビットを集合の要素に対応させて、1なら存在する、0なら存在しない、として状態をあらわし、状態間の逐次計算を行うことをbitDPと呼びます。
このとき、もとの集合の部分集合を、高速に列挙する必要がある場合があります。すべてのビットのすべての部分集合を走査するには、通常は、$O(2^N \times 2^N) = O(4^N)$かかりますが、以下のテクニックにより、自分自身を含み空集合を含まない部分集合を降順に列挙することを、$O(3^N)$で実現することが可能です。
subbit = bit
while subbit > 0:
# 部分集合ごとの処理
subbit = (subbit - 1) & bit
速度測定は省略します。
2.1.6. 要素数kの集合の列挙
bitDPにおいて、最大要素数(=ビット数)N、要素数(=1の数)k(>0)の集合を昇順に列挙するという、魔術的なアルゴリズムです。
bit = (1 << k) - 1
while bit < 1 << N:
# 集合ごとの処理はここに挿入する
lsb = bit & -bit
bit = (lsb + bit) | (((bit & ~(lsb + bit)) // lsb) >> 1)
速度測定は省略します。
2.1.7. 補集合
bitDPの集合bitに対して、部分集合subbitの補集合は、簡単に求められます。
complement_bit = bit - subbit
速度測定は省略します。
2.2. 整数
2.2.1. 床関数と天井関数
競技プログラミングでは、整数÷整数の切り捨て、切り上げ操作が頻出です。通常、ガウス記号で表現されます。
床関数(切り捨て) $$\left\lfloor\frac{x}{y}\right\rfloor$$
天井関数(切り下げ) $$\left\lceil\frac{x}{y}\right\rceil$$
それぞれ、実数計算としては、math.floor
とmath.ceil
がありますが、整数÷整数 に限定して誤差無く高速に求めることができます。
x // y # 床関数(切り捨て)
-(-x // y) # 天井関数(切り上げ)
速度測定は省略します。
int
は「整数部のみ」という関数であるため、int(x / y)
は、正の数の場合は床関数に、負の数の場合は天井関数になることに、注意してください。
2.2.2. 四捨五入
Pythonのround関数やフォーマット識別子による丸めは、いわゆる「銀行丸め」であって、四捨五入とは異なります。厳密に四捨五入したい場合は、decimalを使う必要があります。
参考: [解決!Python]数値を四捨五入する(丸める)には(round関数/decimal.Decimalクラス)
xとyが整数の場合は、天井関数の考え方を応用することで、以下のようにx/yの四捨五入を高速かつ厳密に計算することが可能です。
(x + y // 2) // y # 四捨五入
2.2.3. 平方数の元の数(平方根)
誤差が出そうで怖いですが、x
が少なくとも64bit以内の平方数の場合は、以下にて誤差無く平方根が求められます(小数形式になりますが)。
x ** 0.5
2.2.4. 巨大整数の四則演算
Pythonの整数は、巨大数もスムーズに扱えてオーバーフローも気にしなくてよいので、C++と比べて使い勝手がよいです。しかしながら、あまりにも巨大数だと演算が遅くなるため、巨大数にならないアルゴリズムが必要です。
各演算が、どのくらいの大きさで遅くなるか、ベンチマークしてみます。
N = 10000 # 桁数
import random
# 加減算で桁数がN程度になる頃合いの乱数
A = random.randint(10 ** (N - 1), 10 ** N)
B = random.randint(10 ** (N - 1), 10 ** N)
# A, B とあわせて、乗除算で桁数がN程度になる頃合いの乱数
C = random.randint(10 ** (N // 2 - 1), 10 ** (N // 2))
D = random.randint(10 ** (N // 2 - 1), 10 ** (N // 2))
T = 10 ** 5 # AtCoderの問題の多くは、10万オーダーのループである
for _ in range(T):
A + B # ベンチマーク対象だけコメントアウトする
#A - B
#C * D
#A // C
#A % C
N | 10 | 100 | 1,000 | 10,000 | 50,000 | 100,000 | 500,000 | 1,000,000 | |
---|---|---|---|---|---|---|---|---|---|
A + B | Python3.8 | 34ms | 29ms | 44ms | 132ms | 516ms | 999ms | 4977ms | 10241ms |
PyPy3 | 158ms | 124ms | 126ms | 132ms | 131ms | 142ms | 389ms | 975ms | |
A - B | Python3.8 | 27ms | 30ms | 46ms | 135ms | 532ms | 1015ms | 5083ms | code9 |
PyPy3 | 113ms | 118ms | 117ms | 118ms | 138ms | 149ms | 395ms | 1033ms | |
C * D | Python3.8 | 33ms | 43ms | 291ms | code9 | code9 | code9 | code9 | code9 |
PyPy3 | 123ms | 165ms | 175ms | 256ms | 989ms | 3421ms | code9 | code9 | |
A // C | Python3.8 | 35ms | 55ms | 517ms | code9 | code9 | code9 | code9 | code9 |
PyPy3 | 108ms | 122ms | 135ms | 341ms | 3704ms | code9 | code9 | code9 | |
A % C | Python3.8 | 37ms | 53ms | 518ms | code9 | code9 | code9 | code9 | code9 |
PyPy3 | 110ms | 115ms | 123ms | 278ms | 3699ms | code9 | code9 | code9 |
結果を見ると、PyPy3において、加減算は10**5桁、乗除算は10**3桁が、実用的な速度の目安であることがわかります。PyPy3はPythonと比較して、四則演算でも高速化がされているようです。
64bitまるごと計算することで定数倍高速化するような問題の場合、C++を基準にしているためPyPy3ではTLEになりやすいです。しかし、この計算性能の特性を考慮して、一度に計算するbit数を大きくとることで、ACできる可能性が強まります。(bit数を大きくし過ぎるとMLEになりやすくなるため注意)
2.3. INF
Pythonのfloat('inf')は「関数」であるためオーバーヘッドがあります。そのため、無限大を何度も定義する場合は、INF = float('inf')
と変数に一度保存してコピーするだけで、速度がかなり速くなります。さらに、INF = 10 ** 16
を使う場合があります。なお、10 ** 16
は、10のべき乗で、多少の計算をしても64bit符号付き整数の限界を超えない大きな値、という意味があります。
稀に、INF = float('inf')
ではTLEになり、INF = 大きな数値
でACとなる場合もあります。しかも、その場合、大きな数値 = 10 ** 16
が正しい打ち手ではない場合もありますので、侮れません。応用編にいくつか、実例がありますので、参照ください。
# 関数
[float('inf') for _ in range(N)]
# 変数に保存
INF = float('inf')
[INF for _ in range(N)]
# 10 ** 16
INF = 10 ** 16
[INF for _ in range(N)]
N | 10**6 | 10**7 | 10**8 | |
---|---|---|---|---|
float('inf') |
Python3.8 | 178ms | 1520ms | code9 |
PyPy3 | 148ms | 1095ms | 9795ms | |
INF = float('inf') |
Python3.8 | 65ms | 396ms | 3483ms |
PyPy3 | 124ms | 618ms | 5573ms | |
INF = 10 ** 16 |
Python3.8 | 65ms | 382ms | 3690ms |
PyPy3 | 110ms | 557ms | 5313ms |
この例では、float('inf')
を変数として扱う場合と、10 ** 16
を使う場合で、あまり速度差は無いようですが、応用例では10 ** 16
を使うことで明らかに速度向上する場合もあります。
2.4. リストによる配列や行列の典型処理
配列や行列は、python3.8であればnumpyを使うことで多彩な処理が可能です。しかし、pypy3を使うとnumpyが使えません。そのため、通称のリストを組み合わせて、numpyの典型処理に相当するスニペットを準備しておきます。
この項目は、別記事である「【競プロ】PyPy3で使える!Numpy代用ライブラリ」に引っ越しました。
2.5. 幾何
2変数の配列(=ベクトル)をもとに、基本的な幾何関数を示します。
特に内積、外積は、以下の公式を前提として、ベクトル同士の角度(0度、90度、180度、左右どちら回りか、といった大まかなもの)が重要となる問題で活躍します。
$$内積(dot) \quad p \cdot q = \lvert a \rvert \lvert b \rvert \cos \theta$$
$$外積(det) \quad p \times q = \lvert a \rvert \lvert b \rvert \sin \theta$$
class Pos(list):
def __add__(self, other):
return Pos(p + q for p, q in zip(self, other))
def __sub__(self, other):
return Pos(p - q for p, q in zip(self, other))
def __mul__(self, a):
return Pos(p * a for p in self)
def __truediv__(self, a):
return Pos(p / a for p in self)
def __rmul__(self, a):
return self * a
def __neg__(self):
return self * (-1)
dot = lambda p, q: sum(p * q for p, q in zip(p, q)) #内積
det = lambda p, q: p[0] * q[1] - p[1] * q[0] #外積(2次元のみ)
norm1 = lambda p: sum(abs(p) for p in p)
norm2 = lambda p: dot(p, p)
import math
arg = lambda p: math.atan2(p[1], p[0]) # 極座標の角度(ラジアン)
# 回転
rot = lambda p, arg: Pos([math.cos(arg) * p[0] - math.sin(arg) * p[1], math.sin(arg) * p[0] + math.cos(arg) * p[1]])
# 以下は頻出ではないが知っておくと便利なもの
# 多角形の面積(の2倍)
surface2 = lambda p_list: abs(sum([det(p_list[i-1], p_list[i]) for i in range(len(p_list))]))
# 線分pq上の格子点の数
line_grid = lambda p, q: math.gcd(abs(p[0] - q[0]), abs(p[1] - q[1])) + 1
速度測定は省略します。
3. 速度注意な典型処理
Python、特にPyPyにおいては、以下の3つの処理は「鬼門」です。知らずに素直に実装すると、非常に遅いためTLEしてしまいます。
本項については、AtCoder言語アップデート2023の効果を先取り検証してみました。
3.1. 長大文字列
Pythonでは文字列はイミュータブルであるため、変更操作の際には全ての文字列をコピーする必要があり、遅いです。そのため、長大な文字列操作問題では、リストやdequeの利用が推奨されます。
さらに、PyPy3の文字列処理は遅いことにも注意しましょう。文字列操作問題ではPyPy3でTLEしてPythonでACな時もあります。
# 文字列として処理
S = ''
for _ in range(N):
S += '1'
print(S)
# リストとして処理
S = []
for _ in range(N):
S.append('1')
print(''.join(S))
N | 10**6 | 10**7 | 10**8 | |
---|---|---|---|---|
文字列 | Python3.8 | 167ms | 1339ms | code9 |
PyPy3 | code9 | code9 | code9 | |
リスト | Python3.8 | 123ms | 931ms | 8754ms |
PyPy3 | 132ms | 1033ms | code9 |
N | 10**6 | 10**7 | 10**8 | |
---|---|---|---|---|
文字列 | Python3.11 | code9 | code9 | code9 |
PyPy3(Python3.10) | code9 | code9 | code9 | |
リスト | Python3.11 | 61ms | 548ms | 5346ms |
PyPy3(Python3.10) | 105ms | 622ms | 8731ms |
生Pythonならなんとかなっていた文字列処理は、逆に遅くなってしまいました。リスト処理の方は、全般的に高速化が図られています。
3.2. 多次元リスト
多次元リストは面白い傾向を示します。Python3.8では、リストの多次元化は速度に影響を与えず、添字計算のオーバーヘッドが効くため、次元を減らすと遅くなります。一方、PyPy3では、添字計算のオーバーヘッドはほとんど無く、かつ、リストを1次元にすると明らかに速度がアップします。PyPy3では2次元以上のリストはあまり使わない方がよいでしょう。
なお、1次元化(簡易記述)の例にあるように、 目立たない関数_
を使うことで、3次元コードと比較して違和感の無い記述が可能です。PyPy3ではほとんどオーバーヘッドが無いため、オススメです。
# 3次元
x = [[[0] * K for _ in range(M)] for _ in range(N)]
for n in range(N):
for m in range(M):
for k in range(K):
x[n][m][k] = 1
# 2次元化
x = [[0] * K for _ in range(N * M)]
for n in range(N):
for m in range(M):
for k in range(K):
x[n * M + m][k] = 1
# 1次元化
x = [0] * N * M * K
for n in range(N):
for m in range(M):
for k in range(K):
x[(n * M + m) * K + k] = 1
# 1次元化(簡易記述)
def _(n, m, k): return (n * M + m) * K + k
x = [0] * N * M * K
for n in range(N):
for m in range(M):
for k in range(K):
x[_(n, m, k)] = 1
# 1次元化するとprintデバッグ困難なので、元の次元に戻すサポート関数を準備しておく
inv_faltten = lambda a, shape: [inv_faltten(a[n: n + len(a) // shape[0]], shape[1:]) for n in range(0, len(a), len(a) // shape[0])] if len(shape) > 1 else a
N * M * K | 10**6 | 10**7 | 10**8 | |
---|---|---|---|---|
3次元 | Python3.8 | 142ms | 1226ms | code9 |
PyPy3 | 85ms | 245ms | 1279ms | |
2次元 | Python3.8 | 190ms | 1677ms | code9 |
PyPy3 | 80ms | 239ms | 1198ms | |
1次元 | Python3.8 | 222ms | 1890ms | code9 |
PyPy3 | 116ms | 150ms | 778ms | |
1次元(簡易) | Python3.8 | 268ms | 2386ms | code9 |
PyPy3 | 79ms | 144ms | 781ms |
N * M * K | 10**6 | 10**7 | 10**8 | |
---|---|---|---|---|
3次元 | Python3.11 | 94ms | 859ms | 8305ms |
PyPy3(Python3.10) | 67ms | 150ms | 958ms | |
2次元 | Python3.11 | 113ms | 993ms | code9 |
PyPy3(Python3.10) | 65ms | 139ms | 815ms | |
1次元 | Python3.11 | 131ms | 1221ms | code9 |
PyPy3(Python3.10) | 65ms | 112ms | 584ms | |
1次元(簡易) | Python3.11 | 159ms | 1538ms | code9 |
PyPy3(Python3.10) | 66ms | 114ms | 589ms |
綺麗に全体的に時間短縮していますが、1次元化のテクニックの有効性は変わらないようです。
3.3. 再帰
Pythonの再帰も競プロには注意です。一般的に再帰の動作は遅いですが、特にPyPy3では輪をかけて再帰は遅いです。そのため、再帰なしのアルゴリズムを書く必要性が高いのが、Python/PyPy3を競プロで選択する場合の勘所です。
応用編でも示していますが、再帰ありの方が素直な実装ができることも多いため、再帰なしで書く「シバリ」は、なかなかきついものがあります。なおかつ、再帰なしで無理やり複雑な実装をすることで逆に遅くなるケースもあります。そのため、応用編含む本記事のように、日頃からコードテストで速度を測定して、最適なスニペットを用意しておくことが重要になります。
どうしても再帰で書く必要がある場合、PyPy3ではなくPythonで通してみるのも考えましょう。
PyPy3の場合、
import pypyjit
pypyjit.set_param('max_unroll_recursion=-1')
と書くことで、再帰の性能が改善する場合があります。(逆効果な場合もあります)
なお、再帰ありの場合、設定で再帰回数上限を拡張しておくことが必要です。また、メモ化再帰という、計算結果をキャッシュしておく機能を使うことで、高速化ができる場合があります。
# 1をN回加算(再帰なし)
ans = 0
for n in range(1, N + 1):
ans += 1
print(ans)
# 1をN回加算(再帰あり)
import sys
sys.setrecursionlimit(10 ** 9)
def ans(n):
if n == 0:
return 0
else:
return ans(n - 1) + 1
print(ans(N))
# 1をN回加算(メモ化再帰)
import sys
sys.setrecursionlimit(10 ** 9)
from functools import lru_cache
@lru_cache(maxsize=None)
def ans(n):
if n == 0:
return 0
else:
return ans(n - 1) + 1
print(ans(N))
N | 10**6 | 10**7 | 10**8 | |
---|---|---|---|---|
再帰なし | Python3.8 | 102ms | 767ms | 7184ms |
PyPy3 | 67ms | 80ms | 199ms | |
再帰あり | Python3.8 | 688ms | code139 | code139 |
PyPy3 | 1223ms | code256 | code256 | |
メモ化再帰 | Python3.8 | 1019ms | code139 | code139 |
PyPy3 | code30720 | code30720 | code30720 |
code139はスタックオーバーフローの可能性あり、です。スタックで確保したメモリが、メモリ上限(2GB)を超えたようです。またcode256は、この場合は再帰上限超過です。PyPy3では、recursionlimitを10**6までにしか設定できないようです。code30720は、メモ化再帰のメモ領域がメモリオーバーフローしているようです。なお、この実装ですと、同じnでans()を繰り返し呼び出しているわけではありまけせんので、メモ化再帰の効果はでていません。
言語アップデート2023検証
N | 10**6 | 10**7 | 10**8 | |
---|---|---|---|---|
再帰なし | Python3.11 | 68ms | 658ms | 5464ms |
PyPy3(Python3.10) | 59ms | 69ms | 173ms | |
再帰あり | Python3.11 | 128ms | 1193ms | code9 |
PyPy3(Python3.10) | 722ms | code256 | code256 | |
メモ化再帰 | Python3.11 | 532ms | code139 | code139 |
PyPy3(Python3.10) | code256 | code256 | code256 |
全体的に時間短縮しています。特に生Pythonの再帰性能がかなり改善されたようです。PyPyが再帰苦手なのは変わらないようです。
4. 数学関数
4.1. 素数列挙
素数を列挙する、いわゆるエラトステネスのふるい法です。
4.1.1. エラトステネスのふるい法(素直な実装)
def primes(n):
sieve = [True] * (n + 1)
sieve[0] = sieve[1] = False
for i in range(2, int(n ** 0.5) + 1):
if sieve[i]:
for j in range(i * 2, n + 1, i):
sieve[j] = False
return [i for i, s in enumerate(sieve) if s]
理論的には計算量は$O(n\log\log n)$であり、オーバーヘッドを除くと、そのような結果になっていると思います。基本的に$O(n)$程度だと、N=10**8
(1億)がTLEになる目安です。
N | 10**6 | 10**7 | 10**8 |
---|---|---|---|
Python3.8 | 171ms | 1989ms | code9 |
PyPy3 | 218ms | 663ms | 8498ms |
このくらい力任せの計算だとPyPy3の高速性が活きてきます。
4.1.2. エラトステネスのふるい法(高速化)
偶数を予め「ふるい」から外しており高速化と省メモリ化を図っています。さらなる高速化も可能ですが、効果と複雑さのトレードオフがありますので、AtCoder向けのスニペットとしては、このくらいまでで十分でしょう。
def primes(n):
sieve = [True] * ((n + 1) // 2)
for i in range(1, (int(n ** 0.5) + 1) // 2):
if sieve[i]:
for j in range(i * 3 + 1, (n + 1) // 2, i * 2 + 1):
sieve[j] = False
res = [i * 2 + 1 for i, s in enumerate(sieve) if s]
res[0] = 2
return res
N | 10**6 | 10**7 | 10**8 |
---|---|---|---|
Python3.8 | 81ms | 928ms | 10049ms |
PyPy3 | 93ms | 329ms | 3796ms |
4.2. 素因数分解
試し割り法で実装します。なおcollections.defaultdict
を使うともう少しエレガントに記述できますが、速度はやや落ちるため、使っていません。
def prime_factors(n):
res = {}
i = 2
while i * i <= n:
while n % i == 0:
n = n // i
res[i] = res.get(i, 0) + 1
i += 1
if n > 1:
res[n] = res.get(n, 0) + 1
return res
for n in range(N): # ベンチマーク用
prime_factors(n)
試し割りは、1つの数に対して$O(\sqrt{n})$と言われていますが、それは素数を素因数分解する時のワーストケースであるとともに、そもそもの処理が単純であるため高速に終了します。そのため、多くの数を連続的に求めてベンチマークして、平均値を求めることで、1つの数あたりの実際の速度を求めてみました。
N | 10**4 | 10**5 | 10**6 |
---|---|---|---|
Python3.8 | 56ms(平均6μs) | 671ms(平均7μs) | code9 |
PyPy3 | 85ms(平均9μs) | 183ms(平均2μs) | 2301ms(平均2μs) |
なお、多くの数の素因数をいっきに求める場合は、最初に素数列を求めておくことで高速化できますし、SPFなどというもっと高速なアルゴリズムもあります。しかし、ベンチマークのために行ったことであり、実際にそのようなシーンは考えにくいです。少数の独立した数の素因数を求める上では、試し割りで十分といえます。SPFを使うと、むしろ、前処理するためのメモリがオーバーフローするリスクが高いです。
4.3. 約数列挙
小さい約数と大きな約数を、同時に列挙していく方法です。
約数の個数は意外と少なく、100,000以下の数の約数の総個数は、1,166,750しかありません。そのため、10万オーダーの約数問題は、全探索が可能な場合が多いです。
def divisors(n):
res_low, res_high = [], []
i = 1
while i * i <= n:
if n % i == 0:
res_low.append(i)
if i != n // i:
res_high.append(n // i)
i += 1
return res_low + res_high[::-1]
for n in range(N): # ベンチマーク用
divisors(n)
これも、1つの数に対して$O(\sqrt{n})$ですが、素因数分解と異なり、元の数を保ったまま約数をテストしていきますので、ワーストケースに寄った結果となります。
N | 10**4 | 10**5 | 10**6 |
---|---|---|---|
Python3.8 | 95ms(平均10μs) | 1971ms(平均20μs) | code9 |
PyPy3 | 90ms(平均9μs) | 451ms(平均5μs) | code9 |
なお、結果をソートすることにこだわらなければ、集合を使って、もう少しだけ速いアルゴリズムが作れますが、速度の差はわずかです。
def divisors(n):
res = set()
i = 1
while i * i <= n:
if n % i == 0:
res.add(i)
if i != n // i:
res.add(n // i)
i += 1
return res
for n in range(N): # ベンチマーク用
divisors(n)
N | 10**4 | 10**5 | 10**6 |
---|---|---|---|
Python3.8 | 90ms(平均9μs) | 1893ms(平均19μs) | code9 |
PyPy3 | 86ms(平均9μs) | 411ms(平均4μs) | 9949ms(平均10μs) |
4.4. 多変数の最大公約数(GCD)、最小公倍数(LCM)
Python3.9からmathで多変数の最大公約数、(2変数以上の)最小公倍数がサポートされました。しかし、AtCoderは2021年4月時点ではPython3.8対応であるとともに、PyPyはもっと対応版数が低いです。よって、自力で求められるようにしておく必要があります。
具体的には、math.gcd
を起点にして2変数のlcmを求め、さらにfunctools.reduce
を利用して各関数を3変数以上に拡張します。
N = int(input()) # Nは使わない
S = list(map(int, input().split()))
import functools
import math
def lcm2(x, y):
gcd = math.gcd(x, y)
return gcd * (x // gcd) * (y // gcd)
gcd = lambda l: functools.reduce(math.gcd, l)
lcm = lambda l: functools.reduce(lcm2, l)
print(gcd(S), lcm(S))
2変数のGCDはユークリッドの互除法により数値の大きさをMとして$O(\log M)$で計算可能です。同様に2変数のLCMもGCDを使うだけなので$O(\log M)$です。これらをreduceで繰り返すため、多変数のgcd、lcmは、Nを数値の個数、Mを数値の大きさとして、$O(N\log M)$で計算可能です。Mを固定とみなすと、$O(N)$となります。
テストケースでは、数値は2固定で、数値の個数をNとしました。
N | 10**5 | 10**6 | 10**7 |
---|---|---|---|
Python3.8 | 53ms | 254ms | 2333ms |
PyPy3 | 82ms | 125ms | 496ms |
4.5. 拡張ユークリッドの互助法、中国剰余定理
何らかのオフセットと周期の組合せを求める際に、拡張ユークリッドの互助法や中国剰余定理を使うと、高速に求めることができます。
sign = lambda a: 0 if a == 0 else 1 if a > 0 else -1
# 拡張ユークリッドの互助法
# ax + by = gcd(a, b)
# 参考 https://qiita.com/akebono-san/items/f00c0db99342a8d68e5d
# 負の係数に対応
def extgcd(a, b):
x, y, u, v = 1, 0, 0, 1
sign_a, sign_b = sign(a), sign(b)
a, b = abs(a), abs(b)
while b:
k = a // b
x -= k * u
y -= k * v
x, u = u, x
y, v = v, y
a, b = b, a % b
return sign_a * x, sign_b * y
# 拡張ユークリッド互助法を使いやすく
# ax + by = c, x >= 0, y >= 0
import math
def extgcd2(a, b, c):
g = math.gcd(a, b)
if c % g != 0: # 解なし
return None, None
x, y = extgcd(a, b)
x, y = x * c // g, y * c // g
a0, b0 = abs(a) // g, abs(b) // g
t0 = min((x % b0 - x) // b0, (y % a0 - y) // a0)
return x + t0 * b0, y + t0 * a0
# 中国剰余定理
# V = [(X_i, Y_i), ...]: X_i (mod Y_i)
# https://tjkendev.github.io/procon-library/python/math/chinese-remainder.html
def crt(V):
x = 0; d = 1
for X, Y in V:
g, a, b = extgcd(d, Y)
x, d = (Y * b * x + d * a * X) // g, d * (Y // g)
x %= d
return x, d
速度測定は省略します。
4.6. 凸関数の最小値3分探索
ときどき、(下に)凸関数の最小値を求める問題が出題されます。単調増加関数の特定値を求める場合は2分探索を使いますが、凸関数の最小値を求める場合は3分探索を使います。なお、2分探索は応用範囲が広いため応用編で扱います。
以下の実際の問題が、凸関数の最小値を求める簡単な例になっています。
N = int(input())
A = list(map(int, input().split()))
# 探索対称となる凸関数
def func(x):
res = 0
for a in A:
res += x + a - min(a, 2 * x)
return res / N
l, r = 0, 10 ** 9 + 1 # 探索する開区間 (l, r)
loss = 10 ** (-7) # 許容誤差の1桁下を設定
# xが整数の場合はloss=2にする
while l + loss < r:
m1 = (l * 2 + r) / 3 # 3分探索点その1 xが整数の場合は//にする
m2 = (l + r * 2) / 3 # 3分探索点その2 xが整数の場合は//にする
if func(m1) < func(m2):
r = m2 # 探索区間を(l, m2)に縮小
else:
l = m1 # 探索区間を(m1, r)に縮小
print(func(l)) # 探索した最小値
# xが整数の場合はfunc(l+1)、func(l+2)も調べてそれらの最小にする
func計算に$O(N)$、3分探索に$O(\log 10^9)$かかりますので、Nを動かした場合は$O(N)$が計算量になります。
N | 10**5 | 10**6 | 10**7 |
---|---|---|---|
Python3.8 | 4139ms | code9 | code9 |
PyPy3 | 170ms | 981ms | 9100ms |
Python3.8だとかなり遅いようです。実数計算はPyPy3の得意領域なのかもしれません。
5. 数学関数(MODあり)
べき乗、階乗、順列、組み合わせなどが、問題を解く道具になることがあります。これらは通常、とても大きな数字になるため、問題の答えとして巨大素数で割った剰余を要求される場合が多いです。
これらの組み込み関数やmathライブラリの関数の多くは剰余を考慮していませんので、剰余を考慮したアルゴリズムを作る必要があります。普通に、組み込み関数やmathライブラリの関数を適用して、出た答えから剰余を求めようとすると、TLEになるように問題が設計されている場合が多いです。
なお、巨大素数としては、MOD = 998244353
がよく使われるため、以下のベンチマークではこの値をもとに実測しました。
10**9付近の素数が使われるため、
MOD = 1000000007
などもあるようですが、$998244353 = 2^{23}\times119+1$という良い性質があるため、MOD = 998244353
がよく使われるようです。
5.1. べき乗(べき剰余)
基本であるべき乗については、組み込み関数が剰余に対応済みです。
res = pow(N, N, MOD) # ベンチマーク用
べき剰余の計算量は$O(\log n)$です。実際にものすごい巨大数であっても、素早く計算可能です。
N | 10**(10**5) | 10**(10**6) | 10**(10**7) |
---|---|---|---|
Python3.8 | 58ms | 436ms | 9588ms |
PyPy3 | 100ms | 260ms | 4999ms |
5.2. 階乗
深いネスト数で再帰にするメリットはありませんので、普通に実装します。N >= MOD
の場合、階乗と剰余の定義から答えは0になることに注意します。
$MOD! \equiv 0 ,(mod,MOD)$の事実を使うと、ウィルソンの定理が求められます。さらに、NがMODに近い際に次項の逆元を組み合わせることで、高速に階乗の剰余を求めることが可能です。)
def factorial(n):
if n >= MOD:
return 0
res = 1
for i in range(2, n + 1):
res = res * i % MOD
return res
計算量は$O(n)$です。$O(\sqrt{n} \log n)$にする方法もあるそうですが難解です。
N | 10**6 | 10**7 | 10**8 |
---|---|---|---|
Python3.8 | 140ms | 1101ms | code9 |
PyPy3 | 72ms | 126ms | 528ms |
5.3. 逆元
5.3.1. 拡張ユークリッドの互除法
nとMODが互いに素の場合、MODを法にしたnの逆元が存在します。逆元を計算することは、この後に出てくる組合わせ計算のために有用です。逆元は「拡張ユークリッドの互除法」で求めるのが効率的です。ただし次項で示すようにPython/PyPyではフェルマーの小定理を使った方が簡明です。
def modinv(a):
b, x, y = MOD, 1, 0
while b:
tmp = a // b
a -= tmp * b
x -= tmp * y
a, b, x, y = b, a, y, x
x %= MOD;
if x < 0:
x += MOD
return x
「拡張ユークリッドの互除法」で1つの逆元を計算する計算量は理論的には$O(\log n)$であり、べき乗と同様に巨大数の逆元を素早く求められます。
N | 10**(10**5) | 10**(10**6) | 10**(10**7) |
---|---|---|---|
Python3.8 | 31ms | 207ms | 7088ms |
PyPy3 | 73ms | 172ms | 3908ms |
実際に巨大数を計算できていますが、10**(10**7)で急激に遅くなっています。巨大数の割り算で遅くなっているのかもしれません。
5.3.2. フェルマーの小定理
フェルマーの小定理をもとに、逆元がべき剰余で求められます。
\begin{align}
n^{MOD-1} &\equiv 1 \,(mod\,MOD) \\
n^{MOD-2} &\equiv n^{-1} \,(mod\,MOD)
\end{align}
modinv = lambda n: pow(n, MOD - 2, MOD)
拡張ユークリッドの互除法と速度はほぼ同じですが、powの1行にコードが単純化できますので、Python/PyPyではこちらを使った方が良さそうです。
5.4. 順列
順列 $_nP_r = n(n-1)(n-2)\cdots(n-r+1)$ は、階乗のコードを少し変更して求めます。
def perm(n, r):
if n >= MOD:
return 0
res = 1
for i in range(n - r + 1, n + 1):
res = res * i % MOD
return res
print(perm(N, N // 2)) # ベンチマーク用
計算量は$O(n-r)$です。
N | 10**6 | 10**7 | 10**8 |
---|---|---|---|
Python3.8 | 83ms | 569ms | 5389ms |
PyPy3 | 74ms | 88ms | 297ms |
5.5. 組み合わせ
組み合わせは、以下の数式で表せるため、これまでのコードの複合で求められます。
$$_nC_r = \frac{n(n-1)(n-2)\cdots(n-r+1)}{r(r-1)\cdots 2\cdot 1} = \frac{_nP_r}{r!}$$
ただし、$MOD \leqq n$の場合の場合、注意が必要です。
$$_nP_r \equiv 0 ,(mod,MOD) , \Rightarrow , _nC_r \equiv 0 ,(mod,MOD)$$
が成立するのは$MOD > n$の時だけで、$MOD \leqq n$の場合は約分により必ずしも成立しないからです。
以下では$MOD > n$を前提にしています。
def comb(n, r):
return (perm(n, r) * modinv(factorial(r))) % MOD
print(perm(N, N // 2)) # ベンチマーク用
計算量は$O(n)$です。
N | 10**6 | 10**7 | 10**8 |
---|---|---|---|
Python3.8 | 138ms | 1102ms | code9 |
PyPy3 | 70ms | 112ms | 536ms |
5.6. 階乗/逆元/順列/組み合わせの前処理方式
これまでの計算について、予め部品計算を済ませておき、1回1回の計算を超高速にする方法が有名です。組み合わせを何度も計算するような問題には有効です。
$n = 10^7 = 10,000,000$ あたりを境界として、通常方式と前処理方式の使い分けをします。ただし、$n = 10^8 = 100,000,000$ くらいになると通常方式では速度的に困難になってきます。一方、前処理方式は、nが大きくてもrが小さい場合には対応可能であり、適用できる範囲が広いようです。
5.6.1 前処理方式(フル)
階乗、逆元、逆元の階乗を、全て前処理計算します。
# 前処理
fact = [None] * (MAX_SIZE + 1)
fact_inv = [None] * (MAX_SIZE + 1)
inv = [None] * (MAX_SIZE + 1)
fact[0] = fact[1] = 1
fact_inv[0] = fact_inv[1] = 1
inv[1] = 1
for i in range(2, MAX_SIZE + 1):
fact[i] = fact[i - 1] * i % MOD
inv[i] = - inv[MOD % i] * (MOD // i) % MOD
fact_inv[i] = fact_inv[i - 1] * inv[i] % MOD
# 関数定義 n >= MOD や n < r の処理はしていません
factorial = lambda n: fact[n]
modinv = lambda n: inv[n]
factorial_modinv = lambda n: fact_inv[n]
perm = lambda n, r: factorial(n) * factorial_modinv(n - r) % MOD
comb = lambda n, r: perm(n, r) * factorial_modinv(r) % MOD
前処理の計算量は$O(n)$です。PyPy3で$10^6$くらいまでが実用範囲です。
MAX_SIZE | 10**5 | 10**6 | 10**7 |
---|---|---|---|
Python3.8 | 119ms | 1172ms | code9 |
PyPy3 | 81ms | 292ms | 4068ms |
前処理終了後の各関数の計算量は$O(1)$であり、一瞬で終わりますのでベンチマークは省略します。
5.6.2 前処理方式(簡易)
計算時間がかかる階乗のみを前処理計算します。
# 前処理
fact = [None] * (MAX_SIZE + 1)
fact[0] = fact[1] = 1
for i in range(2, MAX_SIZE + 1):
fact[i] = fact[i - 1] * i % MOD
# 関数定義 n >= MOD や n < r の処理はしていません
factorial = lambda n: fact[n]
modinv = lambda n: pow(n, MOD - 2, MOD)
perm = lambda n, r: factorial(n) * modinv(factorial(n - r)) % MOD
comb = lambda n, r: perm(n, r) * modinv(factorial(r)) % MOD
前処理の計算量は$O(n)$のままですが、実行時間は前項よりもかなり高速になります。PyPy3で$10^7$くらいまでが実用範囲です。1回あたりの計算量は前項より大きくなりますが、十分高速です。
MAX_SIZE | 10**5 | 10**6 | 10**7 |
---|---|---|---|
Python3.8 | 54ms | 296ms | 2490ms |
PyPy3 | 79ms | 115ms | 775ms |
なお、rが小さい時は、permの計算を通常方式に差し替えることで、nが大きくてもpermやcombを高速に計算することが可能です。
5.7. フィボナッチ数
フィボナッチ数の算出は単純ですが、時々関連問題が出題されるため、記載しておきます。前処理方式です。
fibo = [None] * (MAX_SIZE + 1)
fibo[0], fibo[1] = 0, 1
for n in range(2, MAX_SIZE + 1):
fibo[n] = (fibo[n - 1] + fibo[n - 2]) % MOD
前処理の計算量は$O(n)$です。PyPy3で$10^7$くらいまでが実用範囲です。
MAX_SIZE | 10**5 | 10**6 | 10**7 |
---|---|---|---|
Python3.8 | 60ms | 328ms | 2987ms |
PyPy3 | 67ms | 100ms | 789ms |
前処理終了後の計算量は$O(1)$であり、一瞬で終わりますのでベンチマークは省略します。
5.8. 繰り返し二乗法
5.8.1. 正方行列の累乗
正方行列の累乗(MOD付き)は、応用編で述べるDPで時々使う技法です。2次元リスト典型処理を前提に、特に行列積を使います。
正方行列の次数をmとすると1回の行列積は$O(m^3)$ですので、単純に積を繰り返すと全体の計算量$O(m^3 n)$になります。そこで、累乗を効率的に分解する繰り返し二乗法を使うことで、全体の計算量を$O(m^3 \log n)$に高速化します。
c++であれば、数値のべき剰余の段階で、繰り返し二乗法が登場するのですが、Pythonでは専用の関数pow()
が存在するため、繰り返し二乗法は不要でした。行列の累乗で、はじめて繰り返し二乗法が必要になります。
下記のソースは、別記事である「【競プロ】PyPy3で使える!Numpy代用ライブラリ」の関数を利用しています。
MOD = 10 ** 9 + 7
def matrix_power(x, n):
# 繰り返し二乗形式に分解
n_s = []
while n > 0:
n_s.append(n)
if n % 2 == 0:
n = n // 2
else:
n -= 1
n_s = n_s[::-1]
# 分解を組み立て直しつつ計算
pows = [None] * len(n_s)
for i, n in enumerate(n_s):
if n == 1:
res = x
elif n % 2 == 0:
res = np_matmul(pows[i - 1], pows[i - 1])
else:
res = np_matmul(pows[i - 1], x)
pows[i] = np_mod(res, MOD)
return pows[-1]
正方行列の次数は固定してnを変化させて、速度を測定しました。
n | 10**(10**2) | 10**(10**3) | 10**(10**4) |
---|---|---|---|
Python3.8 | 73ms | 441ms | 4841ms |
PyPy3 | 285ms | 516ms | 1280ms |
5.8.2. 巨大なフィボナッチ数
フィボナッチ数については、以下の漸化式が成立します。
\begin{pmatrix} Fibo(n+1) \\ Fibo(n) \end{pmatrix}
=
\begin{pmatrix} 2 & 1 \\ 1 & 1 \end{pmatrix}
\begin{pmatrix} Fibo(n-1) \\ Fibo(n-2) \end{pmatrix}
これを解くと、正方行列の累乗になります。
\begin{pmatrix} Fibo(2n+1) \\ Fibo(2n) \end{pmatrix}
=
\begin{pmatrix} 2 & 1 \\ 1 & 1 \end{pmatrix} ^n
\begin{pmatrix} Fibo(1)=1 \\ Fibo(0)=0 \end{pmatrix}
よって、これまでのコードを流用することで、巨大なフィボナッチ数が算出可能です。
ans = np_matmul(matrix_power([[2, 1], [1, 1]], N // 2), [[1], [0]])
if N % 2 == 1:
print(ans[0][0])
else:
print(ans[1][0])
計算量は$O(\log n)$です。
n | 10**(10**2) | 10**(10**3) | 10**(10**4) |
---|---|---|---|
Python3.8 | 75ms | 447ms | 4865ms |
PyPy3 | 289ms | 509ms | 1308ms |