はじめに
こんにちは。こんばんはかもしれません。爲岡 (ためおか) と申します。
2020年4月から株式会社グロービスにて機械学習エンジニアとして働いています。
グロービスでは機械学習技術を利用したプロジェクトや、データ基盤の運用改善プロジェクトを担当しています。
機械学習技術を利用したシステムには様々なものがあると思いますが、
現状のグロービスにおいては、ユーザのリクエストに対して機械学習を行い、
すぐに結果を返す必要があるようなシステムは扱っておらず、
ある程度の時間をかけて学習、推定した結果を非同期にアプリケーションに連携するシステムのみを扱っています。
ゆえに、今のところは機械学習技術を利用したロジックを書くときに、速さを意識することはあまりありません。
一方で、空いている時間に競技プログラミングをやっていることもあり、
高速なコードについて考えたり、書いたりすることは個人的には好きです。
今回は、機械学習において利用されることが多い Python のコードについて、
色々な手段を使って実際に高速化を試みつつ、その手順をまとめました。
よろしければご覧いただければと思います。
実行環境
ハードウェア
- Mac Book Pro 2019 年モデル
- プロセッサ: 2.6 GHz 6 コア Intel Core i7
- メモリ: 16 GB
ソフトウェア
- Python 実行環境の Docker image には
python:3.9.1-buster
を利用します。 - リソースはいっぱいいっぱいまで利用できるよう設定しています。
ディレクトリ構成
ディレクトリ構成はこんな感じです。各ファイルの内容に関しては後述します。
fast_python
ディレクトリ以下を Docker コンテナ内の/
に mount して利用しています。
fast_python
├── cython
│ ├── prime.c
│ ├── prime.cpython-37m-x86_64-linux-gnu.so
│ ├── prime.py
│ ├── prime.pyx
│ └── setup.py
├── pybind11
│ ├── prime.cpp
│ ├── prime.cpython-37m-x86_64-linux-gnu.so
│ └── prime.py
└── python
├── fast_prime.py
└── prime.py
高速化の対象
今回は下記のような問題を考え、その回答となるコードの高速化を試みます。
標準入力で与えられる整数 X について、
X 以上の素数のうち、最小のものを求め、標準出力せよ。
ただし、2≤X≤10^14 とする。
※ 参考: AtCoder Beginner Contest 149 C - Next Prime
速さを意識せずに書く
まずは高速化のことは一旦置いて、ただ問題を解く方法について考えます。
2≤X≤10^14 という制約条件を除けば、X 以上の整数について、1つずつ「素数かどうか」を判定し、
もし素数であればその値を返す、というシンプルな問題になるかと思います。
これをコードに表すとこんな感じになりそうです。
def is_prime(x: int) -> bool: # 素数かどうかの判定を行う関数
# まだ考えていない
def minimum_prime_number():
X = int(input().strip())
answer = 0
for i in range(X, (10 ** 14 + 32)):
if is_prime(i):
answer = i
break
print(answer)
if __name__ == '__main__':
minimum_prime_number()
ここで、minimum_prime_number()
の処理の中で、for 文のrange
の stop を10^14 + 32
としていますが、
その理由は 10^14 以上の最小の素数が 10^14 + 31 であるためになります。
制約条件下の X の最大値である 10^14 が入力された場合、 10^14 + 31 が出力されれれば良いため、
今回のケースではこのrange
で十分であるはずです。
次に、「素数かどうか」の判定ロジックについて考えます。
素数とは「1 より大きい自然数で、かつ正の約数が 1 とその数自身のみであるもの」なので、
整数 X が素数かどうかを判定するロジックは下記のようになるかと思います。
今回は 2≤X という制約条件があるため、X が 1 以下のケースについては考えません。
2 から X までの整数 i について、X を i で割ったときの余りをそれぞれ計算する。
もし、余りが 0 となる整数 i が存在する場合は、 False を返す。
存在しない場合は True を返す。
上記のロジックを Python のプログラムに落とすと、下記のようになりました。
def is_prime(x: int) -> bool:
if x <= 1:
return False
for i in range(2, x):
if x % i == 0:
return False
return True
def minimum_prime_number():
X = int(input().strip())
answer = 0
for i in range(X, (10 ** 14 + 32)):
if is_prime(i):
answer = i
break
print(answer)
if __name__ == '__main__':
minimum_prime_number()
実際にプログラムを実行してみます。
# @Docker コンテナ内
root@xxxxxxxxxxxx:/fast_python/python# python prime.py
2 # 入力
2 # 出力
root@xxxxxxxxxxxx:/fast_python/python# python prime.py
3 # 入力
3 # 出力
root@xxxxxxxxxxxx:/fast_python/python# python prime.py
4 # 入力
5 # 出力
root@xxxxxxxxxxxx:/fast_python/python# python prime.py
5 # 入力
5 # 出力
root@xxxxxxxxxxxx:/fast_python/python# python prime.py
6 # 入力
7 # 出力
root@xxxxxxxxxxxx:/fast_python/python# python prime.py
7 # 入力
7 # 出力
... (以下略)
入力された整数以上の素数のうち、最小のものが出力されており、良さそうです。
次はコーナーケースとして 10^14 を入力してみます。
出力としては 10^14 + 31 を期待しています。
root@xxxxxxxxxxxx:/fast_python/python# python prime.py
1000000000000 # 入力
...
おや、全然結果が返ってきません。
一応 10 分くらい待ってみましたが、結果が返ってきませんでした。
入力された数字が大きすぎるために for 文のループ回数が多くなり、処理に非常に時間がかかってしまっているようです。
Python における高速化
今回の問題では処理時間の上限を設けていませんが、10 分経っても結果が返ってこないのはちょっと困ります。
このコードをなんとか高速化して、さっさと結果が返ってくるようにしたいです。
コードの高速化の第一歩として、まずは、どこの処理が遅いのかを特定するのが良いと思います。
特定の方法はいくつかありそうですが、プロファイラを利用すると、各処理にかかる時間が詳細にわかります。
プロファイラ cProfile
を使う
Python のプロファイリングツールもこれまた色々とあるみたいですが、
組み込みのツールとしてはcProfile
というものがあります。
これを使って、各関数の実行時にかかっている時間を見ていきましょう。
公式ドキュメントを参照すると、下記のようなコマンドを実行するだけで
指定の Python コードファイルのプロファイリングができるようです。
root@xxxxxxxxxxxx:/fast_python/python# python -m cProfile prime.py
今回は各処理にかかっている時間を確認したいので、tottime
順に sort されるように
下記のように-s
オプションを付けて実行します。
root@xxxxxxxxxxxx:/fast_python/python# python -m cProfile -s tottime prime.py
tottime
というのは何かと言うと、
「与えられた関数に消費された合計時間 (sub-function の呼び出しで消費された時間は除外されています)」だそうです。
実際に上記のコマンドを実行すると、下記のような出力が得られました。
実行の際の標準入力には、多少処理に時間がかかりますがちゃんと出力が返ってくる、10^7 を指定してみました。
root@xxxxxxxxxxxx:/fast_python/python# python -m cProfile -s tottime prime.py
10000000
10000019
237 function calls in 2.342 seconds
Ordered by: internal time
ncalls tottime percall cumtime percall filename:lineno(function)
1 1.675 1.675 1.675 1.675 {built-in method builtins.input}
20 0.664 0.033 0.664 0.033 prime.py:1(is_prime)
4 0.000 0.000 0.002 0.000 <frozen importlib._bootstrap_external>:1438(find_spec)
16 0.000 0.000 0.001 0.000 <frozen importlib._bootstrap_external>:62(_path_join)
16 0.000 0.000 0.001 0.000 <frozen importlib._bootstrap_external>:64(<listcomp>)
... (以下略)
プロファイラの出力結果を見ると、一番時間がかかっている処理は下記のようです。
ncalls tottime percall cumtime percall filename:lineno(function)
1 1.675 1.675 1.675 1.675 {built-in method builtins.input}
ただ、これは標準入力がされるまでの待ち時間が含まれています。僕の標準入力のタイピングが遅いということです。
要するにこれは無視して良いと思います。
問題は2行目です。
ncalls tottime percall cumtime percall filename:lineno(function)
20 0.664 0.033 0.664 0.033 prime.py:1(is_prime)
is_prime()
が実際の処理の中で一番時間がかかっている関数であり、全体で 0.6 秒以上かかっています。
見たところ他の処理にかかる時間は 0.001 秒以下で無視してよく、ここがボトルネックであるとわかります。
速い Python コードにする
高速化が難しいようなプログラムもあるとは思いますが、今回のケースではロジックの改修による高速化を試みます。
for 文のループ回数が多いために処理に時間がかかっている場合、処理の高速化のためにまず思いつくこととしては、
ループ回数を減らすことだと思います。
実際、今回のケースでは、問題の解答のために必要な条件を満たしつつ、ループ回数を減らすことができます。
is_prime()
において、ループ回数は 2 から「X の平方根以下の最大の整数」まででよいです。
証明はググると出てきますが、こちらのサイトがわかりやすかったです。
※ SSL 対応されていないサイトのためご注意ください。
これを実際にコードに適用したものが下記になります。
import math
def is_prime(x: int) -> bool:
if x <= 1:
return False
for i in range(2, (math.floor(math.sqrt(x)) + 1)): # 平方根以下の最大の整数を上限に設定
if x % i == 0:
return False
return True
def minimum_prime_number():
X = int(input().strip())
answer = 0
for i in range(X, (10 ** 14 + 32)):
if is_prime(i):
answer = i
break
print(answer)
if __name__ == '__main__':
minimum_prime_number()
実際に実行してみると下記のようになりました。標準入力は 10^7 から変更無しです。
root@xxxxxxxxxxxx:/fast_python/python# python -m cProfile -s tottime fast_prime.py
10000000
10000019
277 function calls in 1.744 seconds
Ordered by: internal time
ncalls tottime percall cumtime percall filename:lineno(function)
1 1.738 1.738 1.738 1.738 {built-in method builtins.input}
20 0.001 0.000 0.001 0.000 fast_prime.py:4(is_prime) # is_prime() の処理時間
4 0.001 0.000 0.002 0.000 <frozen importlib._bootstrap_external>:1438(find_spec)
16 0.000 0.000 0.001 0.000 <frozen importlib._bootstrap_external>:62(_path_join)
16 0.000 0.000 0.001 0.000 <frozen importlib._bootstrap_external>:64(<listcomp>)
... (以下略)
こんな感じで、is_prime()
の処理時間が 0.001 秒まで短縮されています。
高速化する前のコードのis_prime()
は 0.664 秒かかっていたので、
処理時間は高速化する前と比べて、実に 1/664 となりました。
これなら、先程は処理に 10 分以上かかった 10^14 を標準入力に与えても結果が返ってきそうです。
root@xxxxxxxxxxxx:/fast_python/python# python fast_prime.py
100000000000000
100000000000031
実際に試してみると、数秒して結果が返ってきました。嬉しいです。
プロファイリングして処理時間を計測してみます。
root@xxxxxxxxxxxx:/fast_python/python# python -m cProfile -s tottime fast_prime.py
100000000000000
100000000000031
318 function calls in 3.529 seconds
Ordered by: internal time
ncalls tottime percall cumtime percall filename:lineno(function)
1 2.602 2.602 2.602 2.602 {built-in method builtins.input}
32 0.922 0.029 0.922 0.029 fast_prime.py:4(is_prime) # is_prime() の処理時間
4 0.000 0.000 0.002 0.001 <frozen importlib._bootstrap_external>:1356(find_spec)
16 0.000 0.000 0.001 0.000 <frozen importlib._bootstrap_external>:62(_path_join)
16 0.000 0.000 0.001 0.000 <frozen importlib._bootstrap_external>:64(<listcomp>)
... (以下略)
is_prime()
の処理時間は 0.922 秒で済みました。
高速化する前は 10 分経っても結果が返ってこなかったわけですから、比較すると大きな差です。
Cython を利用して高速化する
このように、問題の解答のために必要な条件を満たしつつ、ループ回数を減らすことができましたが、
さらに高速化を行う方法を考えてみます。
ロジックを修正することによる高速化はこれ以上見込め無さそうですが、
Cython を利用すると、既存のロジックを変更すること無く、処理を高速化することができる場合があります。
使い方としては、まずは下記のようにpip install
します。
root@xxxxxxxxxxxx:/fast_python/cython# pip3 install cython
次に、下記のように高速化したい処理を、.pyx
という拡張子でファイルに切り出します。
ここに Cython を利用したコードを書いていくのですが、
Cython を利用した高速化の一般的なアプローチとして、利用する変数を下記のように
cdef
という形で C 言語の変数として宣言する方法があります。今回はこれを試してみます。
import cython
import math
def is_prime(x: int) -> bool:
cdef:
long i, stop # C 言語の型を指定して変数宣言する
stop = math.floor(math.sqrt(x) + 1)
if x <= 1:
return False
for i in range(2, stop):
if x % i == 0:
return False
return True
次に以下のようなセットアップファイルを用意します。
※ 参考: cython入門 - Qiita
from distutils.core import setup, Extension
from Cython.Build import cythonize
ext = Extension("prime", sources=["prime.pyx"])
setup(name="prime", ext_modules=cythonize([ext]))
これらが用意できたら、下記のコマンドを実行します。
root@xxxxxxxxxxxx:/fast_python/cython# python setup.py build_ext --inplace
実行すると、カレントディレクトリにprime.c
という C のファイルと
prime.cpython-37m-x86_64-linux-gnu.so
という共有ライブラリのファイルが作られます。
これによって、下記のように、Cython を利用して定義したis_prime()
を
Python コード内でimport
して利用できるようになります。
from prime import is_prime # Cython を利用して定義した is_prime() を import
def minimum_prime_number():
X = int(input().strip())
answer = 0
for i in range(X, (10 ** 14 + 32)):
if is_prime(i):
answer = i
break
print(answer)
if __name__ == '__main__':
minimum_prime_number()
このコードを実行し、プロファイリングしてみます。
root@xxxxxxxxxxxx:/fast_python/cython# python -m cProfile -s tottime prime.py
100000000000000
100000000000031
355 function calls (348 primitive calls) in 3.416 seconds
Ordered by: internal time
ncalls tottime percall cumtime percall filename:lineno(function)
1 2.755 2.755 2.755 2.755 {built-in method builtins.input}
32 0.655 0.020 0.655 0.020 {prime.is_prime} # is_prime() の処理時間
2 0.001 0.001 0.001 0.001 {built-in method _imp.create_dynamic}
5 0.001 0.000 0.002 0.000 <frozen importlib._bootstrap_external>:1356(find_spec)
17 0.000 0.000 0.001 0.000 <frozen importlib._bootstrap_external>:56(_path_join)
... (以下略)
ボトルネックであるis_prime()
の処理時間は 0.655 秒となりました。
先程の処理時間 0.922 秒と比べると、処理時間は 2/3 程度となり、さらなる高速化ができたことになります。
より速いプログラミング言語で書く
コードを高速化したり、Cython を利用したりすることで処理時間は短くなりましたが、
さらに高速化が求められることもあるでしょう。 (言っておいてなんですがそんなにないかも。)
そういう状況においては、より速いプログラミング言語で書くということもアプローチの 1 つとしてあり得ると思います。
今回は Python で書いていたis_prime()
関数のコードを、
高速なプログラミング言語である C++ で書き直した後、Python コードから呼び出して実行してみます。
is_prime()
関数を C++ で書き直したコードは下記になります。
bool is_prime(long x) {
for (long i = 2; i <= sqrt(x); i++) {
if (x % i == 0)
return false;
}
return true;
}
この処理を Python のコードから呼び出したいです。
これもやり方としてはいくつかありそうですが、今回は手軽に利用できるpybind11
を使いました。
Cython と同様、まずは下記のようにpip install
します。
root@xxxxxxxxxxxx:/fast_python/pybind11# pip3 install pybind11
次に、先程作成した C++ のコードに対して、下記のように binding するためのコードを追加します。
#include <pybind11/pybind11.h> // ここを追加
bool is_prime(long x) {
for (long i = 2; i <= sqrt(x); i++) {
if (x % i == 0)
return false;
}
return true;
}
PYBIND11_MODULE(prime, m) { // ここから
m.def("is_prime", &is_prime); //
} // ここまで追加
最後に、下記のコマンドを実行してコンパイルします。
※ 参考: 【Techの道も一歩から】第23回「pybind11を使ってPythonで利用可能なC++ライブラリを実装する」
root@xxxxxxxxxxxx:/fast_python/pybind11# g++ -O2 -Wall -shared -std=c++11 -fPIC `python3 -m pybind11 --includes` prime.cpp -o prime`python3-config --extension-suffix`
実行すると、Cython のときと同様、カレントディレクトリにprime.cpython-37m-x86_64-linux-gnu.so
という
共有ライブラリのファイルが作られます。
これで C++ で定義した関数を Python コード内でimport
して利用できるようになりました。
from prime import is_prime # C++ で定義した関数 is_prime() を import
def minimum_prime_number():
X = int(input().strip())
answer = 0
for i in range(X, (10 ** 14 + 32)):
if is_prime(i):
answer = i
break
print(answer)
if __name__ == '__main__':
minimum_prime_number()
これを実行して、実際にcProfile
を使って処理時間を計測してみると、下記のようになりました。
root@xxxxxxxxxxxx:/fast_python/pybind11# python -m cProfile -s tottime prime.py
100000000000000
100000000000031
138 function calls in 2.760 seconds
Ordered by: internal time
ncalls tottime percall cumtime percall filename:lineno(function)
1 2.670 2.670 2.670 2.670 {built-in method builtins.input}
32 0.085 0.003 0.085 0.003 {built-in method prime.is_prime} # is_prime() の処理時間
1 0.002 0.002 0.002 0.002 {built-in method _imp.create_dynamic}
1 0.000 0.000 2.756 2.756 prime.py:4(minimum_prime_number)
1 0.000 0.000 0.001 0.001 <frozen importlib._bootstrap>:882(_find_spec)
... (以下略)
is_prime()
の処理時間は 0.085 秒となりました。
Cython を利用した場合のis_prime()
の処理時間は 0.655 秒だったので、
そこからさらに 1/8 程度まで処理時間を短縮できました。
他にもまだまだ高速化の手段はありそうですが、今回はここまでとさせてください。
まとめ
各手段に対するボトルネック (is_prime()
) の処理時間の対応表を下記にまとめました。
これらの処理時間は、全て 10^14 を入力としたときのものになります。
|手段|ボトルネックの処理時間 (秒)|
|---|---|---|
|特に対処無し|600 秒以上|
|for ループ回数を減らす高速化|0.922 秒|
|for ループ回数を減らす高速化 & Cython を利用|0.655 秒|
|for ループ回数を減らす高速化 & C++ & pybind11
を利用|0.085 秒|
以上のように、Python コードの高速化の手段は色々とあることがわかりました。
また、ロジックの中で高速化できる箇所が高速化しつつ、C++ とpybind11
を使えば
かなりの高速化が見込めることがわかりました。
実際、特に対処無しのコードと比べると、C++ とpybind11
を利用した場合の処理時間は 1/7000 以下になっています。
今回は速さという観点に限った検証ですので、これだけを見ると
pybind11
をガンガン使っていけば速くなるしいいじゃん、となりそうですが、
実際の現場においては、保守性や工数、リソース状況など、様々な要素を考慮しつつ方針を決めるのが良いと思います。
最後までお読みいただき、ありがとうございました。