想定読者
- Pythonを(特にAtCoderなどの競技プログラミングで)使っていて、実行速度を改善したい方
- 計算量に関する初歩的な知識を持っている(ないしは他の記事などで補いながら読める)方
- 特に「計算量は正しいはずなのにPythonの遅さのせいでTLEしてしまう!」という方
はじめに
言語によって実行速度は異なります。全く同じ処理を書いたとしてもC++では10msで実行できたのにPythonだと100ms、みたいなことはよくあります。実行時間制限が言語問わず一定の競技プログラミングにおいては、実行速度が遅い言語は不利にもなりかねません。Pythonはそんな、実行速度が遅い言語の一つであり、不利な選択肢と言われることも多いです。
しかし!
Pythonはなんと言っても書きやすい。簡潔に記述できるし読みやすい。どうにかPythonで競技プログラミングを戦いたい...ということで、少しでも処理を高速化するための知見を集めてみたいと思います。
なお、基本的には「計算量は同じだけれど、高速に処理できる書き方」を扱います。いわゆる定数倍です。計算量O(N)のアルゴリズムをO(1)に改善するといった、アルゴリズムそのものの改善はここでは扱いません。あくまで「計算量としては最善なはずなのに、Pythonが遅いせいでうまくいかない」という悩みを対象にしたいと思います。(逆に言うと、計算量そのものが不適切で実行時間制限オーバーになっているものを、ここに書かれているテクニックで改善することは恐らくできません)
今回は以下の3点について、計測結果とともに記載します。
1.最大値を更新するときはmax関数ではなくif文を利用する
2.探索済みの管理はsetではなくlistを利用する
3.二次元listの方向と走査順
4.forループの書き方別の速度比較 (2021/6/7追加)
テクニック
##1.最大値を更新するときはmax関数ではなくif文を利用する
以下は、「list Aに含まれる値の最大値」を求めるものです(※)。
最大値を更新する際はmax(今までに発見した最大値, 新しく発見した値)
と書く方も多いのではないかと思います。(※本来はmax(A)で求められる処理ですが、便宜上この形で実装しています。)
max_value = 0
for i in range(N):
new_value = A[i]
max_value = max(max_value, new_value)
print(max_value)
しかしながら上記の処理は、より原始的に見えるif
を使った以下の記法のほうが高速に処理されます。
計測結果を下に示します。
max_value = 0
for i in range(N):
new_value = A[i]
if max_value < new_value:
max_value = new_value
print(max_value)
測定内容
- list Aのサイズは10^5
- list Aの要素は、1~10^9の値をランダム生成
- 同一のlistに対して、maxを用いたプログラムとifを用いたプログラムとでそれぞれ実行して速度を計測
- 上記を100回試行、その際、各回の開始前にlistの要素をランダム再生成
- 100回の試行の平均値を結果とする
結果
バージョン | 実行時間(秒、小数点以下7桁) |
---|---|
maxを使ったバージョン | 0.0162913 |
ifを使ったバージョン | 0.0053729 |
maxを使ったほうは約3倍の時間が掛かることが分かりました。
原因の推測としては、関数を呼ぶ処理と、代入する処理が負担になっていると考えられます。
最大値の代入を行うときは、maxではなくifを利用しましょう。
##2.探索済みの管理はsetではなくlistを利用する
グラフなどの探索において「探索済みの場所を管理しておき、一度探索した場所を訪れたときには探索を止める」という実装がしばしば必要となります。
たとえば二次元の表において、マス目(i,j)を既に訪問したかを管理するようなケースが挙げられます。
訪問済みの管理は、集合型であるsetを使う方法があります。
visited = set()
while ...
i,j = 探索先を取得する処理
if (i,j) in visited: # 探索済みの場所であるかをチェック
continue # 探索済みなので探索をやめる
visited.add((i,j)) # 初めての場所なので処理を続行するとともに、この場所を探索済みに追加する
...
しかしながら、これはlistを使ったほうが高速に動作します。
visited = [[False] * W for i in range(H)] # 縦H,横Wの表とする
while ...
i,j = 探索先を取得する処理
if visited[i][j]: # 探索済みの場所であるかをチェック
continue # 探索済みなので探索をやめる
visited[i][j] = True # 初めての場所なので処理を続行するとともに、この場所を探索済みに追加する
...
ただし、これには条件があります。listの場合、あらかじめ探索範囲分の真偽表をメモリ上に持つ必要があります。探索される可能性のある範囲が膨大な場合、たとえば縦10^9 * 横10^9の場合などはlistで管理するのは現実的ではないため、set型と使い分ける必要があります。
測定内容
- 1000 * 1000のマス目に対して(0,0)のマスからスタートして、BFSで上下左右のマスに1マスずつ移動し、(999,999)のマスにたどり着いたらゴール
- 一度訪れたマスをsetないしはlistで管理
- 100回の試行の平均値を結果とする
setを使ったバージョン
def with_set():
visited = set()
q = deque([(0,0)])
while q:
i,j = q.popleft()
if (i,j) in visited:
continue
visited.add((i,j))
if i == H - 1 and j == W - 1:
break
for c in ((0,1),(0,-1),(1,0),(-1,0)):
if 0 <= i + c[0] < H and 0 <= j + c[1] < W:
if (i + c[0], j + c[1]) in visited:
continue
q.append((i + c[0], j + c[1]))
listを使ったバージョン
def with_list():
visited = [[False] * W for i in range(H)]
q = deque([(0,0)])
while q:
i,j = q.popleft()
if visited[i][j]:
continue
visited[i][j] = True
if i == H - 1 and j == W - 1:
break
for c in ((0,1),(0,-1),(1,0),(-1,0)):
if 0 <= i + c[0] < H and 0 <= j + c[1] < W:
if visited[i + c[0]][j + c[1]]:
continue
q.append((i + c[0], j + c[1]))
結果
バージョン | 実行時間(秒、小数点以下7桁) |
---|---|
setを使ったバージョン | 2.5052506 |
listを使ったバージョン | 2.0403028 |
setの場合は約1.25倍の時間が掛かりました。
原因の推測としては、setのほうはハッシュ値の算出を行うためでしょうか。
訪問済みの管理はsetではなくlistを利用しましょう。ただし、listの場合はあらかじめ探索可能性がある場所をすべてメモリ上に確保する必要があるため、探索可能な範囲が限定的である場合に限ります。
##3.二次元listの方向と走査順
「N個の要素について、各要素が2種類の状態を持つ」ことを二次元listで管理する際、Nが十分に大きい場合、N * 2のlistで管理するのと2 * Nのlistで管理するのはどちらが効率がよいのでしょうか。またそのlistはどの順番で走査するのが効率的なのでしょうか。
測定内容
- N = 10^5
- 2 * NないしはN * 2の二次元listに対してインデックスでアクセス。すべてのA[i][j]に対してi + jの値を代入する。
- 格納と走査のパターンは以下イメージ。(便宜上N = 5の図として記載)
つまり、縦をH,横をWとしたとき、A[i][j] # 0 <= i < H, 0 <= j < W
となる二次元listを生成する前提で、以下4通りを比較します。
- H = 10^5, W = 2, 二重ループの外側でiを、内側でjを回す
- H = 10^5, W = 2, 二重ループの外側でjを、内側でiを回す
- H = 2, W = 10^5, 二重ループの外側でiを、内側でjを回す
- H = 2, W = 10^5, 二重ループの外側でjを、内側でiを回す
1,2のためにlist Aを、3,4のためにlist Bを事前に準備し、初期値に0を設定します。
N = 10 ** 5
A = [[0] * 2 for i in range(N)]
B = [[0] * N for i in range(2)]
そして、1~4の4通りの走査順で、全てのA[i][j]に対してi+jの値を代入します。を以下の通り実装し、今回も100回の試行の平均を取ります。以下のf_1~f_4の関数を用意しました。
def f_1():
res = 0
for i in range(N):
for j in range(2):
A[i][j] = i + j
return res
def f_2():
res = 0
for j in range(2):
for i in range(N):
A[i][j] = i + j
return res
def f_3():
res = 0
for i in range(2):
for j in range(N):
B[i][j] = i + j
return res
def f_4():
res = 0
for j in range(N):
for i in range(2):
B[i][j] = i + j
return res
結果
バージョン | 実行時間(秒、小数点以下7桁) |
---|---|
パターン1 | 0.0353376 |
パターン2 | 0.0166063 |
パターン3 | 0.0169231 |
パターン4 | 0.0371899 |
速いパターン2種(2,3)と、そうでないパターン2種(1,4)で、2倍くらいの差がでました。
最初のイメージ図に順位(速い順)を記載したものが以下です。
数値が大きいNのほうに沿って連続で処理していくほうが速くなるように見えますね。
なお2のループの回し方(二重ループの外側で回す値をA[i][j]のjに設定)は書き方として混乱しそうなので、3の形で整理できればミスが少なくなりそうです。つまり、下のBの形式で宣言することです。
A = [[0] * 2 for i in range(N)]
B = [[0] * N for i in range(2)] # こちら
##4.forループの書き方別の速度比較
listの要素を全て参照(更新ではなく)したい場合の、以下3種類のforループの書き方を比較します。
list Aの全ての要素の値の和を求めることを題材とします。(forループを回すことが目的の便宜上の主題であるため、sum関数は使用しません)
パターン1. rangeを使用したインデックスでのアクセス
res = 0
for i in range(N):
res += A[i]
パターン2. in [list名]で要素を直接取得する
res = 0
for a in A:
res += a
パターン3. enumerateを使用する
res = 0
for index, a in enumerate(A):
res += a
測定内容
- 一次元list AのサイズNは10^5。
- listの要素を全て参照して値の総和を求める。
- Aの各要素の値は全て1
N = 10 ** 5
A = [1] * N
結果
バージョン | 実行時間(秒、小数点以下7桁) |
---|---|
パターン1 | 0.0059839 |
パターン2 | 0.0031708 |
パターン3 | 0.0054855 |
パターン2の、for a in A
の形式のものが最も高速でした。
なおパターン2では対象要素のindexを取得できないため、indexの値そのものを利用する(特定のindexの要素を更新するなど)必要がない場合であればパターン2の書き方をすることが有効そうです。
indexの値を利用する場合も、全ての要素を先頭から順に見ていく場合はパターン3のenumerate
を使う形式がパターン1のrangeに比べて若干高速でした。なおrangeはlistにアクセスする順番についてより細かい指定ができる(例:listの途中の要素からループを開始する、一つおきに要素にアクセスする、後ろから逆順にループを見ていく、等)ため、こうした指定が必要な場合にはrangeを選択するのが適切と考えられます。
更新履歴
- 3.は当初、listのすべての値を足し合わせるものでした。計測に至った意図はlistにインデックスでアクセスする場合の速度比較でしたが、その目的のもとでは不適切(参照であればインデックスでアクセスする必要がない)ため、題材を変更しました。またこれに伴い、listの各要素を参照するforループを複数パターン比較する4.を追加しました。(2021/6/7)
本記事はこちらの企画「競技プログラミング研究月間 - みんなでさらなる高みを目指そう」に参加したものです。
https://qiita.com/official-events/5a0502a2d94ed6a00c30