2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【競プロAtcoder】PyPyで再帰関数を用いたときにTLE/MLE/REが出てしまう時の対処法

Posted at

はじめに

AtCoderなどの競技プログラミングにPython(特にPyPy)で参加する際、再帰関数を用いた実装がTLE(Time Limit Exceeded)、MLE(Memory Limit Exceeded)、RE(Runtime Error)を引き起こすことがあります。

再帰関数によって起きたエラーに対する対処法は次のように考えられます。

  • RE(Runtime Error) → 再帰関数の再帰上限に達している可能性が高い。→ 方法1 を用いる。
  • MLE(Memory Limit Exceeded)・TLE(Time Limit Exceeded) → メモリまたは時間効率の悪いコードになっている可能性がある。→ 方法2・3 を用いる。

本記事では、これらの問題への対処法を詳しく解説します。例として次の問題を取り上げます:C - Make it Forest

REに遭遇するケース

AtCoderの問題をPythonで解く際、再帰で実装するとMLEやTLEが発生することがあります。再帰関数を用いる場面としては,DFS(深さ優先探索)を実装する場合などです.
例えば、次の提出ではREとなります:Submission #64405805 - AtCoder Beginner Contest 399

RE 実装コード
#!/usr/bin/env python3
import sys


def input():
    return sys.stdin.readline().rstrip()


N, M = map(int, input().split())
G = [[] for _ in range(N)]
for _ in range(M):
    u, v = map(int, input().split())
    u -= 1
    v -= 1
    G[u].append(v)
    G[v].append(u)


def dfs(s):
    visited[s] = True
    for v in G[s]:
        if visited[v]:
            continue
        dfs(v)


n_components = 0
visited = [False] * N
for u in range(N):
    if visited[u]:
        continue
    dfs(u)
    n_components += 1

ans = M - N + n_components
print(ans)

MLE/TLEに遭遇するケース

AtCoderの問題をPythonで解く際、再帰で実装するとMLEやTLEが発生することがあります。再帰関数を用いる場面としては,DFS(深さ優先探索)を実装する場合などです.

例えば、次の提出ではMLEとなります:Submission #64405691 - AtCoder Beginner Contest 399

MLE 実装コード
#!/usr/bin/env python3
import sys

sys.setrecursionlimit(10**7)


def input():
    return sys.stdin.readline().rstrip()


N, M = map(int, input().split())
G = [[] for _ in range(N)]
for _ in range(M):
    u, v = map(int, input().split())
    u -= 1
    v -= 1
    G[u].append(v)
    G[v].append(u)


def dfs(s):
    visited[s] = True
    for v in G[s]:
        if visited[v]:
            continue
        dfs(v)

n_components = 0
visited = [False] * N
for u in range(N):
    if visited[u]:
        continue
    dfs(u)
    n_components += 1

ans = M - N + n_components
print(ans)

なお,メモリ使用量を削減するために gc.collect() を挟んでも、この処理自体に時間がかかるため、TLE(Time Limit Exceeded)になることがあります。

1. sys.setrecursionlimit を適切に設定する

Pythonのデフォルトの再帰上限は低いため、再帰関数を用いる場合は、大きな数字(筆者は 10**7 を使用)を設定するのが基本です。

import sys
sys.setrecursionlimit(10**7)

例えば,先のコードにこの設定を追加するだけでREは解消されます.ただし,そのままではMLEになってしまい,ACにはなりません.Submission #64405691 - AtCoder Beginner Contest 399

2. そもそも再帰関数を使わない

再帰関数をそもそも使わない方法を考えることも重要です。例えば連結成分の判定のためにDFSを用いる場合、代わりにBFS(幅優先探索)やスタック・キューを用いた反復的(iterative)な解法に変更すると、メモリ使用量が削減できます。Union-Find(素集合データ構造)を用いても良いです.

例えば、次の提出では、再帰を用いた場合に比べてメモリ使用量が約90%削減(1093400 KB→108140 KB)され,ACを得ることができました。:Submission #64405751 - AtCoder Beginner Contest 399

以下に、非再帰的な解法の一例として、スタックを用いたグラフの連結成分数を求める実装を示します。

非再帰的な解法の実装例
#!/usr/bin/env python3
import sys

sys.setrecursionlimit(10**7)


def input():
    return sys.stdin.readline().rstrip()


N, M = map(int, input().split())

# 連結リストとしてグラフを受け取る
G = [[] for _ in range(N)]
for _ in range(M):
    u, v = map(int, input().split())
    u -= 1
    v -= 1
    G[u].append(v)
    G[v].append(u)

# グラフの連結成分数を求める
# このコードは深さ優先探索とは異なることに注意
n_components = 0
visited = [False] * N
for u in range(N):
    if visited[u]:
        continue
    stack = [u]
    while stack:
        v = stack.pop()
        visited[v] = True
        for neighbor in G[v]:
            if visited[neighbor]:
                continue
            stack.append(neighbor)
    n_components += 1

ans = M - N + n_components
print(ans)

このように、スタックを用いることで、Pythonの再帰の制約を回避しつつ、メモリ消費を抑えた実装が可能です。

3. PyPyの最適化設定(pypyjit.set_param

PyPyは再帰と特に相性が悪いことが知られています。そのため、以下のおまじないを入れることで、メモリ使用量を大幅に削減できます。

import pypyjit
pypyjit.set_param("max_unroll_recursion=-1")

例えば、次の提出ではこの設定を適用するだけで、メモリ使用量が約84%削減(1093400 KB→175764 KB)され、AC(Accepted)を得ることができました。:Submission #64405730 - AtCoder Beginner Contest 399

PyPyの最適化設定
#!/usr/bin/env python3
import sys

sys.setrecursionlimit(10**7)

import pypyjit
pypyjit.set_param("max_unroll_recursion=-1")

def input():
    return sys.stdin.readline().rstrip()


N, M = map(int, input().split())
G = [[] for _ in range(N)]
for _ in range(M):
    u, v = map(int, input().split())
    u -= 1
    v -= 1
    G[u].append(v)
    G[v].append(u)


def dfs(s):
    visited[s] = True
    for v in G[s]:
        if visited[v]:
            continue
        dfs(v)

n_components = 0
visited = [False] * N
for u in range(N):
    if visited[u]:
        continue
    dfs(u)
    n_components += 1

ans = M - N + n_components
print(ans)

まとめ

  • 方法1: 再帰関数を使うときは sys.setrecursionlimit(大きな数字) を適切に設定する。
  • 方法2: BFSやUnion-Findのような再帰を用いない解法はメモリ効率や計算量の面で有利。
  • 方法3: pypyjit.set_param("max_unroll_recursion=-1") を使うとメモリ消費が削減される。

PyPyを用いた競プロでの最適化にぜひ活用してください!

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?