LoginSignup
12
9

More than 5 years have passed since last update.

勾配降下法の勾配とは何か、可視化してみる

Last updated at Posted at 2018-10-10

はじめに

勾配降下法の勾配のイメージを鮮明にするため主にmatplotlibを使って可視化してみたいと思います。

勾配降下法の詳細については拙著以下の記事を参考にしてみてください。
パーセプトロン(ADALINE)でライブラリを用いず多クラス分類してみる【Python】

勾配を可視化する

今回勾配を可視化するにあたって、例として以下の関数を用います。
$$
z = x^2 + xy
$$

勾配の計算

勾配ベクトルの定義は

$$
\nabla z = \left( \frac{\partial z}{\partial x}, \frac{\partial z}{\partial y} \right)
$$

なので、実際に偏微分を計算すると、

$$
\begin{align}
\nabla z &= \left( \frac{\partial z}{\partial x}, \frac{\partial z}{\partial y} \right) \\
&=\left( 2x + y, x \right)
\end{align}
$$

のようになります。

これらのベクトルを実際に図示して見てみます。

関数と偏微分の定義

まず、関数と偏微分を定義します。

# 元の関数
def f(x, y):
    return x**2 + x*y

# xでの偏微分
def partial_diff_x(x, y):
    return 2*x + y

# yでの偏微分
def partial_diff_y(x):
    return x

ベクトルを描写

次にベクトルを描写してみます。

import matplotlib.pyplot as plt

# グリッドデータ作成
x = np.arange(-3, 3, 0.25)
y = np.arange(-3, 3, 0.25)
xx, yy = np.meshgrid(x, y)

# 偏微分の計算
partial_diff_x_matrix = partial_diff_x(xx, yy)
partial_diff_y_matrix = partial_diff_y(xx)

# 描写
plt.quiver(xx, yy, partial_diff_x_matrix, partial_diff_y_matrix)
plt.show()

このようなグラフが得られます。

矢印はグラフの斜面の高い方へ向いており、矢印が長いほど、その斜面が急であることを示しています。

3次元グラフに図示してみると・・・

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# 高さの計算
z = f(xx, yy)

# 描写
fig = plt.figure()
ax = Axes3D(fig)
ax.plot_surface(xx, yy, z)

実際に、矢印はグラフの斜面の高い方へ向いており、その矢印が長いほど、その斜面が急であることがわかると思います。

まとめ

今回は勾配を可視化してみました。

この記事を読んで少しでも多くの人が勾配のイメージをつかめてくれると嬉しいですー!

12
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
9