LoginSignup
11
11

More than 5 years have passed since last update.

勾配降下法の結果を matplotlib アニメーション出力する

Last updated at Posted at 2017-04-22

はじめに

前回 に引き続き Pythonからはじめる数学入門 シリーズです。今回は

  • 第 6 章
    • 「6.1.2 図形のアニメーションを作る」
    • 「6.1.3 投射軌跡のアニメーション」
  • 第 7 章
    • 「7.6 勾配上昇法を用いて最大値を求める」
    • 「章末 問題 7-2 勾配降下法を実装する」

が関係しています。

書籍中の説明では結果を matplotlib でプロットして画像 (もちろん静止画) に出力していますが、勾配降下法のアルゴリズムが動作する様子をアニメーションで見たくなりました。そこで今回は gif アニメーションに出力してみました。勾配上昇法ならびに勾配降下法についてはコードの載せるのみで、解説はしません。

やること

まず関数の最小値を勾配降下法を使って求めます。次にステップごとに値が小さくなっていく様子を可視化するために gif 形式でアニメーションを出力します。

今回は

f(x) = 3x^2 + 2x

という二次関数の最小値を求めます。

ソースコード

gradient_descent.py
from sympy import Derivative, Symbol, sympify, solve
from numpy import arange
import matplotlib.pyplot as plt
import matplotlib.animation as ani


def gradient_descent(x0, f1x, x, epsilon=1e-6, step_size=1e-4):
  # f1x = 0 の解を持つか調べる。
  if not solve(f1x):
    return

  x_old = x0
  x_new = x_old - step_size * f1x.subs({x: x_old}).evalf()
  X_traversed = []

  while abs(x_old - x_new) > epsilon:
    X_traversed.append(x_new)
    x_old = x_new
    x_new = x_old - step_size * f1x.subs({x: x_old}).evalf()

  return x_new, X_traversed


def draw_graph(f, x):
  X = arange(-1, 1, 0.01)
  Y = [f.subs({x: x_val}) for x_val in X]

  plt.plot(X, Y)


def draw_frame(i, x, X, Y):
  plt.clf()
  draw_graph(f, x)
  plt.scatter(X[i], Y[i], s=20, alpha=0.8)


if __name__ == '__main__':
  x = Symbol('x')
  f = 3 * x ** 2 + 2 * x
  var0 = 0.75  # 勾配降下法の初期値

  d = Derivative(f, x).doit()

  # gradient_descent() は、勾配降下法で求めた最小値と各ステップでの x の値を返す。
  var_min, X_traversed = gradient_descent(var0, d, x)

  print('総ステップ数: {0}'.format(len(X_traversed)))
  print('最小値 (勾配降下法): {0}'.format(var_min))
  print('最小値 (f1x = 0 の解): {0}'.format(float(solve(d)[0])))

  X = X_traversed[::100] # (1)
  Y = [f.subs({x: x_val}) for x_val in X]

  fig = plt.figure(figsize=(6.5, 6.5))

  anim = ani.FuncAnimation(fig, draw_frame, fargs=(x, X, Y), frames=len(X)) # (2)
  anim.save('gradient_descent.gif', writer='imagemagick', fps=10) # (3)

標準出力

総ステップ数: 10792
最小値 (勾配降下法): -0.331667951428822
最小値 (f1x = 0 の解): -0.3333333333333333

解説

(1) 配列の縮小

X = X_traversed[::100]

X_traversed は配列で、最急降下法の各ステップでの x の値が全て含まれています。

総ステップ数 len(X_traversed) は 10,792 です。仮に 10fps つまり 1 秒間に 10 枚ペースでフレームを描画する場合、アニメーションを終えるのにおよそ 1,079 秒も掛かってしまいます。これを数秒のアニメーションに短縮するために、結果の配列 X_traversed の要素を 100 個置きに取り出した新しい配列 x を生成し、アニメーション作成に使用しています。

この方法だと X_traversed の末尾に格納されている要素、つまり最小値に対応する x の値が除外されてしまいます。しかし、まあアニメーションで雰囲気がつかめればそれでいいと思い妥協しています。

(2) FuncAnimation() の呼び出し

FuncAnimation() を呼び matplotlib.animation.Animation オブジェクトを生成しています。

引数については以下の通りです。

引数 説明
fig グラフの大元である Figure オブジェクト。
draw_frame 各フレームごとに呼ばれる関数。draw_frame の第 1 引数には自動的にフレーム番号が渡される。
fargs draw_frame の第 2 引数以降に渡される値。
frames アニメーションのフレーム数。

(3) Animation.save() の呼び出し

Animation.save() を呼び、実際にアニメーションを保存します。

ここで引数 writerimagemagick を指定することで、gif アニメーションを出力することができました。ただし、使用しているマシンに ImageMagick がインストールされていることが前提です。僕は macOS を使っていますが、ImageMagick をインストールしていなかったので、Homebrew でインストールしました。

$ brew install imagemagick

これ以外に特に設定はしませんでした。

アニメーション

gradient_descent.gif

こいつ・・・動くぞ!

こうやってアルゴリズムが動く様子を可視化するのはとても楽しいですね :blush::hearts:

参考

書籍

リスペクト記事

もっとすごくて強いやつらを見てみたい場合は 確率的勾配降下法とは何か、をPythonで動かして解説する という記事をおすすめします :thumbsup:

上位互換です。僕もこんなアニメーションが出力できればと強いモチベーションになりました。ありがとうございます :pray::sparkles:

その他

11
11
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
11
11