今回は前回つくったニューラルネットワークを、Chainerに変換していきます。
今回もお世話になったのは、このページです。ありがとうございます。
#Chainer.Variableの導入
前日つくったNumpyのニューラルネットワークのうち、重み付けのWの微分計算をする部分をVariableでつくります。
このとき、関数はFunctionをつかって定義します。
Variable:データを取り扱う部分。関数である変数。
Function:Variableを変数として扱う関数。
Functionを用いるときは、データ(変数)を必ずVariableとして定義する必要がある。
ここで、Variableを使うことによって、誤差逆伝搬の式がbackwordを使うだけですむ、というのが特徴かと思いました。
それぞれ比べます。
Numpy記述
x = X
y = Y
w1 = W1
w2 = W2
Chainer記述
x = Variable(X)
y = Variable(Y)
w1 = Variable(W1)
w2 = Variable(W2)
np.random.randnで定義した変数をVariableの変数としてセットします。
次に順方向計算です。
Numply記述
h = x.dot(w1)
h_r = np.maximum(h, 0)
y_p = h_r.dot(w2)
Chainer記述
h = F.matmul(x, w1)
h_r = F.relu(h)
y_p = F.matmul(h_r, w2)
内積のdotがmatmulになります。
np,maxmum(h,0)としていたLeRUがleruそのままとなって、記述がはっきりしました。
2乗の差分のルートである、誤差計算
Numpy記述
loss = np.square(y_p - y).sum() / y_size
print(loss)
Chainer記述
loss = F.mean_squared_error(y_p, y)
print(loss.data)
単純に、mean_squared_errorとなりました。
記述そのままです。
重みの微分計算
Numpy記述
grad_y_p = 2.0 * (y_p - y) / y_size
grad_w2 = h_r.T.dot(grad_y_p)
grad_h_r = grad_y_p.dot(w2.T)
grad_h = grad_h_r
grad_h[h < 0] = 0
grad_w1 = x.T.dot(grad_h)
Chainer記述
w1.zerograd()
w2.zerograd()
loss.backward()
こんなに記述が簡単になるのは、びっくりです。
3行ですみますので。
微分値より重み付けを更新するのは同じです。
では、全体の記述になります。
# -*- coding: utf-8 -*-
import numpy as np
import chainer.functions as F
from chainer import Variable
EPOCHS = 300
M = 2
N_I = 3
N_H = 2
N_O = 3
LEARNING_RATE = 1.0e-04
np.random.seed(1)
X = np.random.randn(M, N_I).astype(np.float32)
Y = np.random.randn(M, N_O).astype(np.float32)
W1 = np.random.randn(N_I, N_H).astype(np.float32)
W2 = np.random.randn(N_H, N_O).astype(np.float32)
def sample_2():
x = Variable(X)
y = Variable(Y)
w1 = Variable(W1)
w2 = Variable(W2)
for t in range(EPOCHS):
h = F.matmul(x, w1)
h_r = F.relu(h)
y_p = F.matmul(h_r, w2)
loss = F.mean_squared_error(y_p, y)
print(loss.data)
w1.zerograd()
w2.zerograd()
loss.backward()
w1.data -= LEARNING_RATE * w1.grad
w2.data -= LEARNING_RATE * w2.grad
if __name__ == '__main__':
sample_2()
300回計算させた結果、
lossが1.417から1.095まで減少したのを確認しました。
コード自体もシンプルになりましたし、また直接コードの表記から意味がわかります。
Numpyで作るとなると、式そのものを記述するため、それを導出する必要があります。
Chainerが直感的と言われる所以をしりました。
次は、Function,Variableだけでなく、ChainやOptimizerを使って記述してみます。