3次元のグラフの書き方
ニューラルネットワークの学習で現れる,偏微分(勾配)の理解のため3次元グラフを描いてみる.(参考:数値微分)
######Pythonで,NumPyとmatplotlibを使って3次元グラフを描く
準備
- 3次元なので
mpl_toolkits.mplot3d
などをインポート
ex1-1.py
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np
- 引数の2乗和を計算する関数を例として考える
ex1-2.py
def func1(x, y):
return x**2 + y**2
- 描写データの作成
- 3次元で描写するには2次元メッシュが必要
- 2次元配列をarangeを用いて作る
- x, y をそれぞれ1次元領域で分割する
ex1-3.py
x = np.arange(-3.0, 3.0, 0.1)
y = np.arange(-3.0, 3.0, 0.1)
- 2次元メッシュはmeshgridでつくる
- Xの行にxの行列を,Yは列にyの配列を入れたものになっている
ex1-3.py
X, Y = np.meshgrid(x, y)
Z = func1(X, Y)
- グラフの作成
- figureで2次元の図を生成する
- その後,Axes3D関数で3次元にする
ex1-4.py
fig = plt.figure()
ax = Axes3D(fig)
- 軸ラベルの設定
ex1-5.py
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("f(x, y)")
-グラフ描写
ex1-6.py
ax.plot_wireframe(X, Y, Z)
plt.show()