大学の課題で出されたフラクタル図形の**マンデルブロ集合**で遊びまくったお話です。
課題ではマンデルブロ集合を $-2.5 < x < 1.0\space, \space-1.0 < y < 1.0$ の範囲を $400×400$ のドット絵でを作りました。
しかし、 $400×400$ のドット絵だと輪郭がガタガタでせっかくできたのに感動が足りない野で、もっとピクセル数を増やしたい!
そこで使ったのがnumbaです。
Let's pip install numba
numbaとは
numbaは、Pythonで実装された関数を実行中にコンパイルするJITコンパイラを提供するパッケージです。(実行中→Just-In-Time)
そのため、for
文を高速で回すことができます。フラクタル図形を描画するには、for文をぶんぶん回さなければいけないので今回の目的にぴったり。
単純なコードで実行速度を比較してみる
1〜n までの総和 $\left(\sum_{k=1}^nk\right)$ を
- Pythonのみで実装した関数
- 1の関数に
@jit
デコレータをつけた関数
の両方で求めて、その実行速度を比較してみました。
実行時間にばらつきがあったので、下のコードは平均の時間が計算しやすいjupyterを使いました。
コード
from numba import jit
import matplotlib.pyplot as plt
%matplotlib inline
def sum_py(n):
s = 0
for i in range(n):
s += i
return s
@jit
def sum_jit(n):
s = 0
for i in range(n):
s += i
return s
sum_jit(0) # 初回の実行はコンパイルの時間を含むので、先に実行しておく
ns = [10 ** i for i in range(2, 7)]
time_py =[]
time_jit = []
for n in ns:
a = %timeit -r 3 -n 100 -o sum_py(n)
b = %timeit -r 3 -n 100 -o sum_jit(n)
time_py.append(a.average)
time_jit.append(b.average)
plt.figure()
plt.plot(ns, time_py, label="python")
plt.plot(ns, time_jit, label="jit")
plt.xscale("log")
plt.yscale("log")
plt.legend()
実行結果
縦軸:時間(秒)、横軸:ループ回数
JITでコンパイルした関数は、ループ回数を増やしても全く遅くならない!
フラクタル図形の計算を高速化する
元々のコード
下がnumbaで高速化する前の(1番上の図を描画した)コードです。
import matplotlib.pyplot as plt
from numba import jit
import numpy as np
# 領域の大きさ
(x_min, x_max) = (-2.5, 1.0)
(y_min, y_max) = (-1.0, 1.0)
# 小領域の数 = M x N
(M, N) = (400,) * 2
# 各座標
xs = np.linspace(x_min, x_max, M)
ys = np.linspace(y_min, y_max, N)
# 発散の評価
K = 1000
def is_to_infty(z0):
z = z0
for k in range(K):
if abs(z) > 2:
return np.sqrt(k/K)
z = (z ** 2) + z0
return 1.
def mandelbrot(xs, ys):
P = np.zeros((N, M))
for j, y in enumerate(ys):
for i, x in enumerate(xs):
P[j, i] = is_to_infty(complex(x, y))
return P
img = mandelbrot(xs ,ys)
plt.figure(figsize=((x_max - x_min) * 3.5, (y_max - y_min) * 3.5))
plt.imshow(img, origin='lower', extent=(x_min, x_max, y_min, y_max), cmap=cm.jet, interpolation='nearest')
plt.show()
# plt.savefig("mandelbrot.png", dpi=N//8)
この記事のテーマはnumbaなので、フラクタル図形のマンデルブロ集合についての詳しい解説は省略します。
このコードはmandelbrot
関数内で $400^2$ 回 for
文が繰り返されているので、そこに時間がかかってしまいます。
matplotlibのオプションについて
- そもそも
for
文どうこうよりもplt.imshow(img)
がめちゃくちゃ時間かかります。 - せっかく
N, M
を大きくしても、dpiを指定しないとそのままの画質で保存できません。 - だいたい
img=N//8
にすると、img
の要素数がそのまま画像に画素数になります。(10:16画面の場合)
numbaをつかう
とは言っても、変更点は関数に@jit
をつけるだけ
@jit
def is_to_infty(z0):
...
@jit
def mandelbrot(xs, ys):
...
実行結果の比較
コードの1部分を下のようにして実行結果を比較してみると、
from time import time
start = time()
img = mandelbrot(xs ,ys)
end = time()
print(end - start)
結果
python: 6.6 秒
numba : 0.8 秒
約8倍。しかしこの差は、先ほどの総和の実行速度のグラフのようにfor
文の繰り返し回数が多くなるほど大きくなります。つまり、$400^2$ 回から$10000^2$ 回にすると…パソコンの寿命が先に来そうですね。
もっと細かく!(本題)
さてやっと本題です。numbaはただの手段であって今回の目的はきれいなマンデルブロ集合を見たいということです!
(M, N) = (10000,) * 2
えーっと、1万回ループのネストは1億回ループかな笑
どんな画像ができるんだろう
わくわくどきどき
実際に出力された画像は約6MBで載せれないので、画像が欲しかったらmandelbrot
関数に@jit
をつけて1億回ループを各自ぶん回してください。(2分もしないはずです。)
下の画像(809KB)は、一部を拡大して最初のと比較したものです。
後記
美しい。YouTubeによくある、1点をずっと拡大し続けるアニメーションとかできたら面白そう。