はじめに
前回 に引き続き Pythonからはじめる数学入門 シリーズです。今回は
- 第 6 章
- 「6.1.2 図形のアニメーションを作る」
- 「6.1.3 投射軌跡のアニメーション」
- 第 7 章
- 「7.6 勾配上昇法を用いて最大値を求める」
- 「章末 問題 7-2 勾配降下法を実装する」
が関係しています。
書籍中の説明では結果を matplotlib でプロットして画像 (もちろん静止画) に出力していますが、勾配降下法のアルゴリズムが動作する様子をアニメーションで見たくなりました。そこで今回は gif アニメーションに出力してみました。勾配上昇法ならびに勾配降下法についてはコードの載せるのみで、解説はしません。
やること
まず関数の最小値を勾配降下法を使って求めます。次にステップごとに値が小さくなっていく様子を可視化するために gif 形式でアニメーションを出力します。
今回は
f(x) = 3x^2 + 2x
という二次関数の最小値を求めます。
ソースコード
from sympy import Derivative, Symbol, sympify, solve
from numpy import arange
import matplotlib.pyplot as plt
import matplotlib.animation as ani
def gradient_descent(x0, f1x, x, epsilon=1e-6, step_size=1e-4):
# f1x = 0 の解を持つか調べる。
if not solve(f1x):
return
x_old = x0
x_new = x_old - step_size * f1x.subs({x: x_old}).evalf()
X_traversed = []
while abs(x_old - x_new) > epsilon:
X_traversed.append(x_new)
x_old = x_new
x_new = x_old - step_size * f1x.subs({x: x_old}).evalf()
return x_new, X_traversed
def draw_graph(f, x):
X = arange(-1, 1, 0.01)
Y = [f.subs({x: x_val}) for x_val in X]
plt.plot(X, Y)
def draw_frame(i, x, X, Y):
plt.clf()
draw_graph(f, x)
plt.scatter(X[i], Y[i], s=20, alpha=0.8)
if __name__ == '__main__':
x = Symbol('x')
f = 3 * x ** 2 + 2 * x
var0 = 0.75 # 勾配降下法の初期値
d = Derivative(f, x).doit()
# gradient_descent() は、勾配降下法で求めた最小値と各ステップでの x の値を返す。
var_min, X_traversed = gradient_descent(var0, d, x)
print('総ステップ数: {0}'.format(len(X_traversed)))
print('最小値 (勾配降下法): {0}'.format(var_min))
print('最小値 (f1x = 0 の解): {0}'.format(float(solve(d)[0])))
X = X_traversed[::100] # (1)
Y = [f.subs({x: x_val}) for x_val in X]
fig = plt.figure(figsize=(6.5, 6.5))
anim = ani.FuncAnimation(fig, draw_frame, fargs=(x, X, Y), frames=len(X)) # (2)
anim.save('gradient_descent.gif', writer='imagemagick', fps=10) # (3)
標準出力
総ステップ数: 10792
最小値 (勾配降下法): -0.331667951428822
最小値 (f1x = 0 の解): -0.3333333333333333
解説
(1) 配列の縮小
X = X_traversed[::100]
X_traversed
は配列で、最急降下法の各ステップでの x
の値が全て含まれています。
総ステップ数 len(X_traversed)
は 10,792 です。仮に 10fps つまり 1 秒間に 10 枚ペースでフレームを描画する場合、アニメーションを終えるのにおよそ 1,079 秒も掛かってしまいます。これを数秒のアニメーションに短縮するために、結果の配列 X_traversed
の要素を 100 個置きに取り出した新しい配列 x
を生成し、アニメーション作成に使用しています。
この方法だと X_traversed の末尾に格納されている要素、つまり最小値に対応する x の値が除外されてしまいます。しかし、まあアニメーションで雰囲気がつかめればそれでいいと思い妥協しています。
(2) FuncAnimation() の呼び出し
FuncAnimation() を呼び matplotlib.animation.Animation オブジェクトを生成しています。
引数については以下の通りです。
引数 | 説明 |
---|---|
fig | グラフの大元である Figure オブジェクト。 |
draw_frame | 各フレームごとに呼ばれる関数。draw_frame の第 1 引数には自動的にフレーム番号が渡される。 |
fargs | draw_frame の第 2 引数以降に渡される値。 |
frames | アニメーションのフレーム数。 |
(3) Animation.save() の呼び出し
Animation.save() を呼び、実際にアニメーションを保存します。
ここで引数 writer
に imagemagick
を指定することで、gif アニメーションを出力することができました。ただし、使用しているマシンに ImageMagick がインストールされていることが前提です。僕は macOS を使っていますが、ImageMagick をインストールしていなかったので、Homebrew でインストールしました。
$ brew install imagemagick
これ以外に特に設定はしませんでした。
アニメーション
こいつ・・・動くぞ!
こうやってアルゴリズムが動く様子を可視化するのはとても楽しいですね
参考
書籍
リスペクト記事
もっとすごくて強いやつらを見てみたい場合は 確率的勾配降下法とは何か、をPythonで動かして解説する という記事をおすすめします
上位互換です。僕もこんなアニメーションが出力できればと強いモチベーションになりました。ありがとうございます