前回の記事「PyTorchの自動微分を用いて関数の接線を描画してみた」に続いて、2変数関数の接平面を描画するコードを書いてみました。流れは接線の場合と同じです。
接平面の方程式を求める
前回の接線との対比で言えば、次元を一つ上げて、関数を$z=f(x,y)$とし、$(x,y)=(x_0,y_0)$での関数の勾配ベクトルを$(a,b)$とすれば、接平面の方程式の一般形は$z=ax+by+c=(a,b)\cdot(x,y)+c$となり、$\left(x_0, y_0, f(x_0, y_0)\right)$を通るように定数$c$を調整すればよいということになります。
f(x_0, y_0) = (a, b) \cdot ( x_0, y_0) +c \\
\therefore c = f(x_0, y_0) - (a, b) \cdot ( x_0, y_0)
接線の方程式と比較すると、変数が1次元から2次元になる以外は、普通の積が内積に変わるだけです。
次のコードは、この結果に基づいて接平面の方程式を求める関数です。後で接平面を描画するとき使うために、勾配ベクトル$(a,b)$と定数$c$もリターン値に含めています。
import numpy as np
import matplotlib.pyplot as plt
import torch
def tangent_surface2d(f, xy0:torch.Tensor):
z0 = f(xy0) # z0 = f(x0, y0)
z0.backward()
with torch.no_grad():
dxy = xy0.grad.clone() # 後のxy0.grad.zero_()で0クリアされた内容を参照しないようにコピーしておく!
c = z0 - dxy@xy0 # c = f(x0,y0) - (a,b)*(x0,y0)
xy0.grad.zero_()
c = c.numpy()
dxy = dxy.data.numpy()
return lambda xy : xy@dxy + c, dxy, c
引数xy0
には$(x_0,y_0)$を表すテンソルを渡します。まず、z0 = f(xy0)
で$z_0=f(x_0,y_0)$とし、z0.backward()
とdxy = xy0.grad.clone()
をコールして、関数$f(x,y)$の$(x,y)=(x_0,y_0)$での勾配ベクトル$(a,b)$を求めます。次に、c = z0 - dxy@xy0
で$c = f(x_0, y_0) - (a, b) \cdot ( x_0, y_0)$を計算します。
次のコードはこの関数を呼び出す例です。特に、第2引数xy0
は、requires_grad_()
をコールしてbackward()
で自動微分するテンソルであることを指定しておく必要があります。
def f(x): return torch.sin(torch.sqrt((x**2).sum(axis=-1)))
x0 = -1.7
y0 = -0.2
xy0 = torch.tensor([x0, y0]).requires_grad_()
tf, dxy, c = tangent_surface2d(f, xy0)
PyTorchの自動微分については@sudominoruさんの「【PyTorchチュートリアル②】Autograd: 自動微分」に分かりやすい説明があります。
描画用のツール関数を用意する
次のコードは、2変数関数$z=f(x,y)$の曲面と接平面、勾配ベクトルをそれぞれ描画する関数です。内容はコメントを見れば大体わかります。
# 2変数関数z=f(x,y)の描画
def draw_function2d(ax, f, N, xrange=[-4., 4.], yrange=[-4., 4.]):
X, Y = np.meshgrid( # x-y平面上の等間隔の格子座標を得る
np.linspace(xrange[0], xrange[1], N), # x軸 N点
np.linspace(yrange[0], yrange[1], N) # y軸 N点
)
XY = np.c_[np.ravel(X), np.ravel(Y)] # (x, y)の配列 shape=(N**2, 2)
Z = f(torch.tensor(XY)).numpy() # z = f(x, y)
Z = Z.reshape(X.shape) # Zのshapeを(N,N)に変更
return ax.plot_surface(X, Y, Z, cmap='bwr', linewidth=0, alpha=0.3)
# 接平面の描画
def draw_tanjent_plane(ax, tf, x, xw=1.5, yw=1.5):
X, Y = np.meshgrid( # 指定した(x,y)を中心とする接平面4頂点のx-y座標を得る
np.linspace(x[0]-xw, x[0]+xw, 2), # 接平面の頂点x座標
np.linspace(x[1]-yw, x[1]+yw, 2) # 接平面の頂点y座標
)
XY = np.c_[np.ravel(X), np.ravel(Y)] # 接平面4頂点のx,y座標の配列 shape=(4, 2)
Z = tf(XY).reshape(X.shape) # 接平面4頂点のZ座標を求めてshapeを(2,2)に変更
ax.plot_surface(X, Y, Z, color='grey', alpha=0.3) # 接平面の描画
ax.scatter3D(x[0], x[1], 0., c='b', s=20, marker='x') # (x0, y0, 0)
ax.scatter3D(x[0], x[1], f(x), c='b', s=20, marker='x') # (x0, y0, f(x0, y0))
# line (x0, y0, 0) --> (x0, y0, f(x0, y0))
ax.plot([x[0],x[0]], [x[1], x[1]], [0, f(x)],linestyle = "dashed", c='b', linewidth=1)
# 勾配ベクトルの描画
def draw_grad_vector(ax, f, x, grad):
ax.quiver(x[0], x[1], f(x), 1.0, 1.0, grad[0]+grad[1], arrow_length_ratio=0.08)
勾配ベクトルの描画はコードだけでは分かりにくいですが、$x$軸方向の勾配ベクトル$\left(1, 0, \frac{\partial f(x,y)}{\partial x}\right)$と$y$軸方向の勾配ベクトル$\left(0, 1, \frac{\partial f(x,y)}{\partial y}\right)$の和を描画しています。
接平面を描画する
次のコードは、上記のツール関数を用いて、2変数関数の曲面と接平面、勾配ベクトルを描画するものです。ここでは、とりあえず、$f(x,y)=sin(\sqrt{x^2+y^2})$としています。また、x0
とy0
は$(x_0,y_0)$を表す変数で、フォーム項目としてセルの実行時に変更可能にしています。また、N
は、関数$z=f(x,y)$の曲面を面分割するメッシュのX軸とY軸方向の格子点の数で、同じくフォーム項目としてセルの実行時に変更可能にしています。
ちなみに、get_equation_symbpl
という関数は、得られた接平面の式を文字列として表示するためだけのもので、係数$a$、$b$、$c$の符号にうまく対処したくて書きました。
(注)フォーム項目はGoogle Colab固有の機能のようです
import sympy
# 平面の方程式 z = ax + by + c
def get_equation_symbpl(a, b, c):
x = sympy.Symbol('x')
y = sympy.Symbol('y')
return a*x + b*y + c
from mpl_toolkits.mplot3d import Axes3D
########### 2変数関数f(x, y)
def f(x): return torch.sin(torch.sqrt((x**2).sum(axis=-1)))
########### 指定点(x0,y0)での接平面の方程式を得る
x0 = -1.7#@param {type:"number"}
y0 = -0.2#@param {type:"number"}
xy0 = torch.tensor([x0, y0]).requires_grad_() # 接平面を計算する点のx-y座標(x0, y0)
tf, dxy, c = tangent_surface2d(f, xy0) # 接平面 z = tf(x,y) = dxy[0]x + dxy[1]y + c
########### 2変数関数f(x,y)とその指定点(x0,y0)での接平面を描画する
fig = plt.figure(figsize=(20, 6))
ax = Axes3D(fig)
ax.set_zlim3d(-1.0, 1.3)
# 格子サイズ
N = 1000 #@param {type:"integer"}
surf = draw_function2d(ax, f, N) # 曲面f(x,y)を描画する
fig.colorbar(surf) # Z値のカラーバーを描画する
xy0 = xy0.data # 勾配データgradは使わない
draw_tanjent_plane(ax, tf, xy0, xw=2, yw=2) # (x0, y0)での接平面を描画する
draw_grad_vector(ax, f, xy0, dxy) # (x0, y0)での勾配ベクトル(a,b)を描画する
# 接平面の方程式をタイトルに描画する
title = f'Tangent plane: $z = {get_equation_symbpl(dxy[0], dxy[1], c)}$'
ax.set_title(title)
# 軸ラベルを描画する
ax.set_xlabel('X axis'), ax.set_ylabel('Y axis'), ax.set_zlabel('$Z=F(x,y)$')
fig.show()
これを実行した結果が下の図です。
ax.set_zlim3d(-1.0, 1.3)
のところは、ターゲットの関数によって引数を調整した方がよいでしょう。